Skip to content

Commit

Permalink
gpu:cuda: Fix matmul parameters for inner_product usages
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Dec 19, 2024
1 parent c3b5d23 commit 8eddd9b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/gpu/nvidia/cudnn_matmul_executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
memory_tracking::names::key_matmul_dst_in_acc_dt)
: xpu::sycl::interop_memory_arg_t<
::sycl::access::mode::read_write>();
auto arg_block_a_scratch = params->source_size_ != 0
auto arg_block_a_scratch = params->weight_size_ != 0
? CTX_SCRATCH_SYCL_MEMORY(
memory_tracking::names::key_gemm_blocked_a)
: xpu::sycl::interop_memory_arg_t<
::sycl::access::mode::read_write>();
auto arg_block_b_scratch = params->weight_size_ != 0
auto arg_block_b_scratch = params->source_size_ != 0
? CTX_SCRATCH_SYCL_MEMORY(
memory_tracking::names::key_gemm_blocked_b)
: xpu::sycl::interop_memory_arg_t<
Expand Down Expand Up @@ -457,10 +457,10 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
matmul_params->reorder_scratch_size_, cuda_stream->queue());

uint8_t *block_a_scratch_ptr
= alloc_ptr(matmul_params->source_size_, cuda_stream->queue());
= alloc_ptr(matmul_params->weight_size_, cuda_stream->queue());

uint8_t *block_b_scratch_ptr
= alloc_ptr(matmul_params->weight_size_, cuda_stream->queue());
= alloc_ptr(matmul_params->source_size_, cuda_stream->queue());

uint8_t *block_c_scratch_ptr
= alloc_ptr(matmul_params->dest_size_, cuda_stream->queue());
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/nvidia/cudnn_matmul_lt_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ struct cudnn_matmul_lt_impl_t {
}
if (!params->w_blocked_) {
transform_matrix(lt_handle, params, a_layout, a,
blocked_a_layout, block_a_scratch, !params->trans_a_,
blocked_a_layout, block_a_scratch, params->trans_a_,
streamId);
a = block_a_scratch;
}
Expand Down

0 comments on commit 8eddd9b

Please sign in to comment.