doc: graph: sdpa: fix wordings
TaoLv committed Dec 20, 2024
1 parent eff6df6 commit e5933c4
Showing 1 changed file with 17 additions and 18 deletions: doc/graph/complex_fusion/sdpa.md
@@ -1,9 +1,9 @@
 Scaled Dot-Product Attention (SDPA) {#dev_guide_graph_sdpa}
 ===========================================================
 
-## Background
+## Overview
 
-Scaled Dot-Product Attention (SDPA) was introduced in [1] as the core operation
+Scaled Dot-Product Attention (SDPA) is introduced in [1] as the core operation
 of Transformer block which now becomes the backbone of many language models and
 generative models (BERT, Stable Diffusion, GPT, etc.).
@@ -30,9 +30,9 @@ SDPA graph, getting partition from the graph, and optimizing the kernels
 underneath. In general, an SDPA pattern is defined as a directional acyclic
 graph (DAG) using oneDNN Graph API.
 
-### Floating point SDPA
+### Floating-point SDPA
 
-oneDNN defines floating point (f32, bf16, or f16) SDPA as follows. The blue
+oneDNN defines floating-point (f32, bf16, or f16) SDPA as follows. The blue
 nodes are required when defining an SDPA pattern while the brown parts are
 optional.
@@ -74,12 +74,12 @@ optional.
 ![SDPA-Reorder](images/sdpa-reorder.png)
 
 
-## Data types
+## Data Types
 
-oneDNN supports the floating point SDPA pattern with data types f32, bf16, and
-f16. oneDNN users can specify the data type via the input and output logical
-tensors' data type fields for each operation. oneDNN does not support mixing
-different floating data types in a floating point SDPA pattern.
+oneDNN supports the floating-point SDPA pattern with data types f32, bf16, and
+f16. You can specify the data type via the input and output logical tensors'
+data type fields for each operation. oneDNN does not support mixing different
+floating data types in a floating-point SDPA pattern.
 
 oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision,
 int8-bf16 mixed precision, and int8-f16 mixed precision data types.
@@ -91,14 +91,13 @@ platforms follow the general description in @ref dev_guide_data_types.
 
 1. oneDNN primitive-based SDPA is implemented as the reference implementation on
    both Intel Architecture Processors and Intel Graphics Products. In this case,
-   floating point SDPA patterns are usually implemented with f32/bf16/f16 matmul
-   (with post-ops) and softmax primitives, while quantized SDPA patterns are
-   implemented with int8 matmul (with post-ops) and f32/bf16/f16 softmax
-   primitives. The reference implementation requires memory to store the
+   floating-point SDPA patterns are usually implemented with f32, bf16, or f16
+   matmul (with post-ops) and softmax primitives, while quantized SDPA patterns
+   are implemented with int8 matmul (with post-ops) and f32, bf16, or f16
+   softmax primitives. The reference implementation requires memory to store the
    intermediate results of the dot products between Query and Key which takes
-   \f$O(S^2)\f$ memory. It may lead to Out-of-Memory when computing long
+   \f$O(S^2)\f$ memory. It may lead to out-of-memory error when computing long
    sequence length input on platforms with limited memory.
-
 2. The SDPA patterns functionally supports all input shapes meeting the shape
    requirements of each operation in the graph. For example, Add, Multiply,
    Divide, and Select operations require the input tensors to have the same
@@ -114,20 +113,20 @@ platforms follow the general description in @ref dev_guide_data_types.
 4. GPU
    - Optimized implementation is available for 4D Q/K/V tensors with shape
     defined as (N, H, S, D).
-   - Optimized implementation is available for floating point SDPA with `f16`
+   - Optimized implementation is available for floating-point SDPA with `f16`
     data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix
     Extensions (Intel(R) XMX) support.
 
 ## Example
 
 oneDNN provides an [SDPA
 example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/sdpa.cpp)
-demonstrating how to construct a typical floating point SDPA pattern with oneDNN
+demonstrating how to construct a typical floating-point SDPA pattern with oneDNN
 Graph API on CPU and GPU with different runtimes.
 
 oneDNN also provides a [MQA (Multi-Query Attention)
 example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/mqa.cpp) [3]
-demonstrating how to construct a floating point MQA pattern with the same
+demonstrating how to construct a floating-point MQA pattern with the same
 pattern structure as in the SDPA example but different head number in Key and
 Value tensors. In MQA, the head number of Key and Value is always one.

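The \f$O(S^2)\f$ memory note in the changed text can be illustrated with a short reference sketch. This is NumPy, not oneDNN code; the (N, H, S, D) tensor layout follows the doc above, and all names are illustrative. A reference SDPA materializes the full S-by-S score matrix before the softmax, which is exactly the intermediate that grows quadratically with sequence length.

```python
import numpy as np

def sdpa_reference(q, k, v, scale=None):
    """Reference SDPA: softmax(Q @ K^T * scale) @ V.

    q, k, v: arrays of shape (N, H, S, D). The (S, S) score matrix is
    materialized in full, which is the O(S^2) intermediate memory the
    doc warns about for long sequence lengths.
    """
    d = q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(d)                     # default 1/sqrt(D) scaling
    scores = (q @ np.swapaxes(k, -1, -2)) * scale    # (N, H, S, S) intermediate
    scores -= scores.max(axis=-1, keepdims=True)     # for numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)               # softmax over the key axis
    return p @ v                                     # back to (N, H, S, D)

# Tiny smoke test: output shape matches Q.
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 4, 8))
k = rng.standard_normal((1, 2, 4, 8))
v = rng.standard_normal((1, 2, 4, 8))
out = sdpa_reference(q, k, v)
print(out.shape)  # (1, 2, 4, 8)
```

Fused SDPA implementations avoid storing the (S, S) matrix at once by tiling the softmax computation, which is why the optimized oneDNN kernels scale to longer sequences than this reference.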