Skip to content

Commit

Permalink
vulkan: further optimize mul_mat_vec using larger loads (ggerganov#10387
Browse files Browse the repository at this point in the history
)

* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.

Add some early returns for nonexistent rows in mul_mat_vec shaders. These
can only be hit when dispatching a 2D grid of workgroups. Fix the logic
for the 2D grid of workgroups to round up.

Enable the pipeline robustness extension if it's available, and use it to
disable robustness for these pipelines. The instructions to do the bounds
checking contend for the same ALU resources as the bit twiddling dequant
instructions.

* vulkan: Add GLSL structure aliases for quant types to allow larger loads

In Vulkan it's not possible to cast pointer types, so instead you have to
declare an aliased binding for the memory with a different type. This
commit adds aliases for the quant formats using 16b ints, and in a few
places where the struct size is a multiple of 4 also using 32b ints.
Currently only q4_k's aliases are used, but others will be used in
subsequent commits.

* vulkan: use larger loads in q5_k and q6_k shaders.

Similar to the optimization I did in q4_k recently, this vectorizes some loads
and reduces the number of bit twiddling instructions.

* vulkan: use larger K step per iteration in mul_mat_vec.

Add vec4 dequantization functions, and use them to do K=8 per iteration in
mul_mat_vec. This uses 16b loads for the quant values and 128b loads for B
which helps reduce the load on the memory system.

The K_PER_ITER==2 logic is still there, just for F16/F32, and really only
because they support unaligned sizes.

Tweak the num_iters/unrolling logic to be simpler and catch a couple missed
unrolling opportunities.
  • Loading branch information
jeffbolznv authored Nov 20, 2024
1 parent ad21c9e commit 1bacb9f
Show file tree
Hide file tree
Showing 11 changed files with 459 additions and 149 deletions.
101 changes: 63 additions & 38 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif

#include "types.comp"

#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
Expand All @@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
}
#endif

#if defined(DATA_A_Q4_1)
Expand All @@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
}
#endif

#if defined(DATA_A_Q5_0)
Expand All @@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
}
#endif

#if defined(DATA_A_Q5_1)
Expand All @@ -50,13 +78,28 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
}
#endif

#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a[a_offset + ib].d);
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
}
#endif

#if defined(DATA_A_IQ4_NL)
Expand All @@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
}
#endif
82 changes: 72 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types : require

#include "mul_mat_vec_base.comp"

Expand All @@ -12,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;

#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif


uint a_offset, b_offset, d_offset, y_offset;

shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
const uint col = i*BLOCK_SIZE + 2*tid;
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index

#if K_PER_ITER == 8
#if QUANT_R == 2
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
#else
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
Expand All @@ -34,16 +66,32 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
if (!OOB) {
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
}
#endif
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index

#if K_PER_ITER == 8
const vec4 v = dequantize4(ib, iqs, a_offset);
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
#else
const vec2 v = dequantize(ib, iqs, a_offset);

// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
if (!OOB) {
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
}
#endif
}
}

Expand All @@ -61,22 +109,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
temp[i] = FLOAT_TYPE(0);
}

const int unroll_count = 8;

const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);

uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i, false);
i += 2;
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i, true);
i += 2;
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
i++;
}

// sum up partial sums and write back result
Expand Down Expand Up @@ -106,6 +165,9 @@ void main() {
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

Expand Down
34 changes: 9 additions & 25 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE tmp[32];

// Declare aliased versions of A and B bindings that can use 16b/32b loads for
// the quantized values, and vec4 loads for B.
struct block_q4_K_u32
{
f16vec2 d;
uint32_t scales[3*QUANT_K/64/4];
uint32_t qs[QUANT_K/2/4];
};

struct block_q4_K_u16
{
f16vec2 d;
uint16_t scales[3*QUANT_K/64/2];
uint16_t qs[QUANT_K/2/2];
};

layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

Expand Down Expand Up @@ -64,9 +48,9 @@ void main() {
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
Expand All @@ -80,8 +64,8 @@ void main() {
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));

uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];

uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
Expand Down
Loading

0 comments on commit 1bacb9f

Please sign in to comment.