cuda : optimize argmax #10441
Conversation
Thank you. After I had already written the code, and especially after #10318, I've been thinking that I set the wrong priorities for this kernel. I think the only comment of mine that needs to be addressed is the one about undefined behavior; the rest are only suggestions.
if (val > maxval) {
    maxval = val;
    argmax = col;
}
In retrospect it probably makes more sense to do it like this; conditional statements are problematic for code optimization since they prevent the compiler from reordering instructions, but there isn't much to do in one loop iteration anyway.
I couldn't measure a meaningful difference in performance, and this should be easier to understand and maintain. Maybe on some hardware it would make a difference? I would also expect the compiler to be able to optimize simple conditionals like this, but that may be expecting too much.
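For comparison, the branchless alternative being discussed could be sketched as a small helper (hypothetical, not code from this PR):

    // sketch of a hypothetical branchless update, for comparison only
    static __device__ __forceinline__ void update_argmax(const float val, const int col, float & maxval, int & argmax) {
        const bool gt = val > maxval; // single predicate
        maxval = gt ? val : maxval;   // compiles to selects rather than a branch
        argmax = gt ? col : argmax;
    }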
ggml/src/ggml-cuda/argmax.cu (outdated)
if (warp_id == 0 && lane_id < n_warps) {
    maxval = shared_maxval[lane_id];
    argmax = shared_argmax[lane_id];
    const unsigned int mask = (1u << n_warps) - 1u;
It's probably faster to just have all threads participate in the shuffle unconditionally.
The reason for doing this is that if there are fewer than 32 warps, some slots will never be written to shared memory, so they should not be used.
My suggestion would be to just have the first warp clear the memory and then do a __syncthreads before reading again.
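A minimal sketch of that suggestion, reusing the variable names from the snippets above and assuming at most 32 warps per block:

    // sketch: warp 0 fills all 32 per-warp slots with neutral values first,
    // so slots of warps that do not exist still hold valid data
    if (warp_id == 0) {
        shared_maxval[lane_id] = -FLT_MAX;
        shared_argmax[lane_id] = -1;
    }
    __syncthreads();

    // lane 0 of every warp stores its partial result
    if (lane_id == 0) {
        shared_maxval[warp_id] = maxval;
        shared_argmax[warp_id] = argmax;
    }
    __syncthreads();

    // warp 0 can now read all 32 slots and reduce with the full warp mask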
if (warp_id == 0 && lane_id == 0) {
    dst[row] = argmax;
}
My experience is that conditional returns/continues are faster than conditional writes but it probably doesn't matter much.
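In other words, something like this (a sketch of the alternative, not a required change):

    // conditional return: the other threads exit early,
    // and the final write itself is unconditional
    if (warp_id != 0 || lane_id != 0) {
        return;
    }
    dst[row] = argmax;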
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const int64_t num_blocks  = nrows;
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
This is going to be efficient for 32 <= ne00 <= 1024 and ne00 >> 1024 but inefficient for 1024 < ne00 <= 4096. And in general, if you have a variable block size you should make it a template parameter.
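A hedged sketch of what that could look like (hypothetical names, not the code in this PR): the block size becomes a compile-time constant and the host picks a specialization based on ne00.

    // sketch: block size as a template parameter so the compiler can fully
    // unroll the per-thread loop and the reduction for each specialization
    template <int block_size>
    static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols) {
        // ... per-row argmax using block_size threads (body omitted in this sketch) ...
    }

    // hypothetical host-side dispatch on the row size
    static void launch_argmax(const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows, cudaStream_t stream) {
        const dim3 grid((unsigned int) nrows, 1, 1);
        if (ncols <= 128) {
            argmax_f32< 128><<<grid,  128, 0, stream>>>(x, dst, ncols);
        } else if (ncols <= 256) {
            argmax_f32< 256><<<grid,  256, 0, stream>>>(x, dst, ncols);
        } else if (ncols <= 512) {
            argmax_f32< 512><<<grid,  512, 0, stream>>>(x, dst, ncols);
        } else {
            argmax_f32<1024><<<grid, 1024, 0, stream>>>(x, dst, ncols);
        }
    }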
ggml/src/ggml-cuda/argmax.cu (outdated)
for (int offset = 16; offset > 0; offset >>= 1) {
    const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE);
    const int   col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE);
The CUDA documentation says:
"Threads may only read data from another thread which is actively participating in the __shfl_sync() command. If the target thread is inactive, the retrieved value is undefined."
It doesn't explicitly mention __shfl_xor_sync, but I suspect that the same hardware is used and that thus the same limitations apply.
Looking at the PTX documentation, the behavior is definitely undefined.
I don't understand what the undefined behavior is here, can you elaborate? The mask is set such that only the threads participating in the sync are used.
The PTX documentation reads:
"Note that results are undefined if a thread sources a register from an inactive thread or a thread that is not in membermask."
The problem is that even if you limit the participating threads via the mask, they are still retrieving data from threads outside the mask. You would have to dynamically change the values of offset, and in the most general case where n_warps is not a power of 2 you would need to use instructions other than __shfl_xor_sync.
Ok, thanks. It should be fixed now.
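For reference, the well-defined pattern is roughly the following, a sketch assuming the unused shared-memory slots are initialized as suggested above so that all 32 lanes of warp 0 can participate with the full mask:

    // sketch: every lane of warp 0 reads a slot (slots of missing warps hold
    // -FLT_MAX / -1), so the whole warp takes part in the shuffle
    if (warp_id == 0) {
        maxval = shared_maxval[lane_id];
        argmax = shared_argmax[lane_id];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
            const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
            if (val > maxval) {
                maxval = val;
                argmax = col;
            }
        }
    }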
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
You may want to also check the case with ne01 and ne00 flipped, where whether or not the writes are coalesced makes a comparatively larger difference. But that would be the case with a very large batch size and few classes, and especially with language models that have large vocabulary sizes I think it's not an important use case.
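For illustration, such a flipped case would be along the lines of the following (hypothetical shape, not a test case from this PR):

    // hypothetical flipped shape: few classes, many rows
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {10, 2000, 1, 1}));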
Do you mean testing for correctness or performance? These cases are only the ones used in eval mode.
I also tested the performance with [512, 32000], and it drops to 480 GB/s (compared to 730 GB/s with [32000, 512]). There are surely more optimization opportunities, but I don't think it is worth spending more time on this at the moment.
I only meant performance. I wrote the code on master in the context of the ggml MNIST example with an input shape of {10, 1000, 1, 1}. In principle, if you have a low number of classes but a large number of datapoints, the number of writes should become significant and it would make sense to try and coalesce them (but with the code on master there are likely also issues with tail effects because the number of CUDA blocks is reduced by a factor of 32). In the first place, I should have written code with a use case like {256000, 128, 1, 1} in mind since that is going to be relevant for llama.cpp.
Co-authored-by: Johannes Gäßler <[email protected]>
static __global__ void argmax_f32(
        const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
    float maxval = -FLT_MAX;
    int   argmax = -1;
Looking at the code again, I think either 64 bit should be used for the ne00 dimension or there should be an assert that 32 bit is enough.
The output is int32, so it would definitely not work with ne00 larger than INT_MAX. In that case it might make more sense to add the assert to ggml_argmax instead. Other arg* functions will have the same issue.
* cuda : optimize argmax
* remove unused parameter
ggml-ci
* fixup : use full warps
ggml-ci
* Apply suggestions from code review
Co-authored-by: Johannes Gäßler <[email protected]>
* fix ub
* ggml : check ne00 <= INT32_MAX in argmax and argsort
---------
Co-authored-by: Johannes Gäßler <[email protected]>
I was curious about the CUDA implementation to see if it could be used as a reference for the Metal implementation and figured it could be optimized. It processes one row per group and uses multiple warps if the row size is big enough.
Also renamed the loop parameter of the warp shuffles to offset, since that should be more accurate.
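Concretely, the per-row scan part of that structure looks roughly like this (a sketch with names adapted from the snippets above; the actual kernel may differ in details):

    // sketch: one block per row; each thread strides over the row's columns
    const int64_t  row  = blockIdx.x;
    const float  * rowx = x + row * ncols;

    float maxval = -FLT_MAX;
    int   argmax = -1;

    for (int64_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = (int) col;
        }
    }
    // ... followed by the warp/block reduction discussed in the review comments above ...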