Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add int64 and uint64 types for all algorithms and tests #1

Merged
merged 1 commit into from
Dec 11, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/all_gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,12 @@ public:
case ncclDouble:
return ncclAllGatherWithType<double>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclInt64:
return ncclAllGatherWithType<long long>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
case ncclUint64:
return ncclAllGatherWithType<unsigned long long>(sendbuff, recvbuff, count, comm,
numUnroll, stream);
}
return ncclInvalidType;
}
Expand Down
2 changes: 2 additions & 0 deletions src/all_gather_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
#endif
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);

printf("\n");

Expand Down
6 changes: 6 additions & 0 deletions src/all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ public:
case ncclDouble:
return ncclAllReduceWithType<double>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclInt64:
return ncclAllReduceWithType<long long>(sendbuff, recvbuff, count, op,
comm, stream);
case ncclUint64:
return ncclAllReduceWithType<unsigned long long int>(sendbuff, recvbuff, count, op,
comm, stream);
}

return ncclInvalidType;
Expand Down
2 changes: 2 additions & 0 deletions src/all_reduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ int main(int argc, char* argv[]) {
#endif
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);

printf("\n");

Expand Down
4 changes: 4 additions & 0 deletions src/broadcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ public:
return ncclBcastWithType<float>(buff, count, root, comm, numUnroll, stream);
case ncclDouble:
return ncclBcastWithType<double>(buff, count, root, comm, numUnroll, stream);
case ncclInt64:
return ncclBcastWithType<long long>(buff, count, root, comm, numUnroll, stream);
case ncclUint64:
return ncclBcastWithType<unsigned long long>(buff, count, root, comm, numUnroll, stream);
}
return ncclInvalidType;
}
Expand Down
2 changes: 2 additions & 0 deletions src/broadcast_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ int main(int argc, char* argv[]) {
#endif
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);

printf("\n");

Expand Down
20 changes: 20 additions & 0 deletions src/common_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ struct MULTI<FUNC, double> {
}
};

// MULTI specialization for unsigned 64-bit elements.
// The static_assert guarantees a PackType is exactly as wide as one
// unsigned long long, so FUNC is applied once to the two packed operands.
template<class FUNC>
struct MULTI<FUNC, unsigned long long> {
  static_assert(sizeof(PackType) == sizeof(unsigned long long),
      "PackType must be the same size as unsigned long long.");

  // Combines x and y with FUNC and returns the result as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    unsigned long long combined = FUNC()(x, y);
    return combined;
  }
};

// MULTI specialization for signed 64-bit elements.
// The static_assert guarantees a PackType is exactly as wide as one
// long long; both operands are converted to long long before FUNC runs so
// that FUNC sees signed values.
template<class FUNC>
struct MULTI<FUNC, long long> {
  static_assert(sizeof(PackType) == sizeof(long long),
      "PackType must be the same size as long long.");

  // Combines x and y with FUNC on their signed values and returns the result
  // as a PackType.
  __device__ PackType operator()(const PackType x, const PackType y) const {
    long long combined = FUNC()(static_cast<long long>(x),
                                static_cast<long long>(y));
    return combined;
  }
};

template<typename T, bool FETCHTWO>
__device__ inline void FetchOneOrTwo64b(PackType& s0,
const volatile T * __restrict__ const src0, PackType& s1,
Expand Down
4 changes: 3 additions & 1 deletion src/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ typedef enum { ncclChar = 0,
#endif
ncclFloat = 3,
ncclDouble = 4,
nccl_NUM_TYPES = 5 } ncclDataType_t;
ncclInt64 = 5,
ncclUint64 = 6,
nccl_NUM_TYPES = 7 } ncclDataType_t;

/* Reduces data arrays of length count in sendbuff into recvbuf using op operation.
* recvbuf may be NULL on all calls except for root device.
Expand Down
4 changes: 4 additions & 0 deletions src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ public:
return ncclReduceWithType<float>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclDouble:
return ncclReduceWithType<double>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclInt64:
return ncclReduceWithType<long long>(sendbuff, recvbuff, count, op, root, comm, stream);
case ncclUint64:
return ncclReduceWithType<unsigned long long>(sendbuff, recvbuff, count, op, root, comm, stream);
}
return ncclInvalidType;
}
Expand Down
6 changes: 6 additions & 0 deletions src/reduce_scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,12 @@ public:
case ncclDouble:
return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclInt64:
return ncclReduceScatterWithType<long long>(sendbuff, recvbuff, recvcount,
op, comm, stream);
case ncclUint64:
return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff, recvcount,
op, comm, stream);
}
return ncclInvalidType;
}
Expand Down
2 changes: 2 additions & 0 deletions src/reduce_scatter_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ int main(int argc, char* argv[]) {
#endif
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);

printf("\n");

Expand Down
2 changes: 2 additions & 0 deletions src/reduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ int main(int argc, char* argv[]) {
#endif
RunTests<float>(N / sizeof(float), ncclFloat, comms, dList);
RunTests<double>(N / sizeof(double), ncclDouble, comms, dList);
RunTests<long long>(N / sizeof(long long), ncclInt64, comms, dList);
RunTests<unsigned long long>(N / sizeof(unsigned long long), ncclUint64, comms, dList);

printf("\n");

Expand Down
26 changes: 26 additions & 0 deletions src/test_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ void GenerateRandom<double>(curandGenerator_t generator, double * const dest,
CURAND_CHK(curandGenerateUniformDouble(generator, dest, N));
}

// GenerateRandom specialization for unsigned long long output.
// Asks `generator` to write N values to `dest` through
// curandGenerateLongLong; the call is wrapped in CURAND_CHK (defined outside
// this excerpt).
// NOTE(review): per the cuRAND docs curandGenerateLongLong is only valid for
// 64-bit quasirandom generators -- confirm every caller creates one.
template<>
void GenerateRandom<unsigned long long>(curandGenerator_t generator, unsigned long long * const dest,
const int N) {
CURAND_CHK(curandGenerateLongLong(generator, dest, N));
}


template<typename T>
void Randomize(T* const dest, const int N, const int randomSeed) {
Expand All @@ -100,6 +106,24 @@ void Randomize(T* const dest, const int N, const int randomSeed) {
CUDACHECK(cudaDeviceSynchronize());
}

// Randomize specialization for unsigned long long buffers.
// Creates a 64-bit Sobol quasi-random generator, fills dest with N values via
// GenerateRandom<unsigned long long>, destroys the generator and then calls
// cudaDeviceSynchronize().
//
//   dest       - buffer passed on to GenerateRandom<unsigned long long>.
//   N          - number of values to generate.
//   randomSeed - selects the starting position in the Sobol sequence, so that
//                different seeds produce different data.
template<>
void Randomize(unsigned long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t gen;
  CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64));
  // Previously randomSeed was ignored, so every call produced identical data.
  // Quasi-random generators take no seed; use it as the absolute offset into
  // the sequence instead. Converting through unsigned int stops a negative
  // seed from sign-extending into an offset near 2^64.
  CURAND_CHK(curandSetGeneratorOffset(gen,
      static_cast<unsigned long long>(static_cast<unsigned int>(randomSeed))));
  GenerateRandom<unsigned long long>(gen, dest, N);
  CURAND_CHK(curandDestroyGenerator(gen));
  CUDACHECK(cudaDeviceSynchronize());
}

// Randomize specialization for long long buffers.
// Creates a 64-bit Sobol quasi-random generator and fills dest with N values
// by viewing it as an unsigned long long buffer (the generator only emits
// unsigned 64-bit words), then destroys the generator and calls
// cudaDeviceSynchronize().
//
//   dest       - buffer to fill; reinterpreted as unsigned long long*.
//   N          - number of values to generate.
//   randomSeed - selects the starting position in the Sobol sequence, so that
//                different seeds produce different data.
template<>
void Randomize(long long* const dest, const int N, const int randomSeed) {
  curandGenerator_t gen;
  CURAND_CHK(curandCreateGenerator(&gen, CURAND_RNG_QUASI_SOBOL64));
  // Previously randomSeed was ignored, so every call produced identical data.
  // Quasi-random generators take no seed; use it as the absolute offset into
  // the sequence instead. Converting through unsigned int stops a negative
  // seed from sign-extending into an offset near 2^64.
  CURAND_CHK(curandSetGeneratorOffset(gen,
      static_cast<unsigned long long>(static_cast<unsigned int>(randomSeed))));
  GenerateRandom<unsigned long long>(gen, reinterpret_cast<unsigned long long*>(dest), N);
  CURAND_CHK(curandDestroyGenerator(gen));
  CUDACHECK(cudaDeviceSynchronize());
}

#ifdef CUDA_HAS_HALF
__global__ void halve(const float * src, half* dest, int N) {
for(int tid = threadIdx.x + blockIdx.x*blockDim.x;
Expand Down Expand Up @@ -268,6 +292,8 @@ std::string TypeName(const ncclDataType_t type) {
#endif
case ncclFloat: return "float";
case ncclDouble: return "double";
case ncclInt64: return "int64";
case ncclUint64: return "uint64";
default: return "unknown";
}
}
Expand Down