diff --git a/lightllm/models/llama/triton_kernel/silu_and_mul.py b/lightllm/models/llama/triton_kernel/silu_and_mul.py
index 42a54794f..5bd2c295e 100644
--- a/lightllm/models/llama/triton_kernel/silu_and_mul.py
+++ b/lightllm/models/llama/triton_kernel/silu_and_mul.py
@@ -17,8 +17,8 @@ def _silu_and_mul_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    stride_input_m = stride_input_m.to(tl.int64)
-    stride_output_m = stride_output_m.to(tl.int64)
+    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
+    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
 
     tid = tl.program_id(0)
     input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -53,7 +53,7 @@ def _silu_and_mul_kernel(
     )
 
 
-def silu_and_mul_fwd(input, output):
+def silu_and_mul_fwd(input: torch.Tensor, output):
     stride_input_m = input.stride(0)
     stride_input_n = input.stride(1)
     stride_output_m = output.stride(0)
@@ -88,13 +88,13 @@ def torch_silu_and_mul(input: torch.Tensor):
 def test_silu_and_mul(M, N, dtype, device="cuda"):
     # create data
     X = torch.randn((M, N), dtype=dtype, device=device)
-
+    y_tri = torch.empty((M, N // 2), dtype=dtype, device=device)
     # run
-    y_tri = silu_and_mul_fwd(X)
+    silu_and_mul_fwd(X, y_tri)
     y_ref = torch_silu_and_mul(X)
 
     # compare
     print("type:", y_tri.dtype, y_ref.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
-    assert torch.allclose(y_tri, y_ref, atol=1e-6, rtol=0)
+    assert torch.allclose(y_tri, y_ref, atol=1e-5, rtol=0)
     return
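
After this patch, silu_and_mul_fwd no longer returns a tensor; callers preallocate the output and pass it in, mirroring the updated test. Below is a minimal usage sketch, not part of the patch itself: it assumes the module path shown in the diff headers, an available CUDA device, and the usual gate-first packing (first half of the columns is the gate, second half the up projection).

import torch

from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

M, N = 1024, 8192
x = torch.randn((M, N), dtype=torch.float16, device="cuda")
# The input packs gate and up projections side by side, so the output
# has half the columns: out[:, j] = silu(x[:, j]) * x[:, j + N // 2]
# (assumed gate-first layout; the shape follows the test above).
y = torch.empty((M, N // 2), dtype=torch.float16, device="cuda")
silu_and_mul_fwd(x, y)  # kernel writes the result into the preallocated y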