Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ORTModule] Symbolic Shape Support for Triton Codegen #18317

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions orttraining/orttraining/python/training/ort_triton/_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,9 @@ class TritonCodegen(NodeVisitor):
Specialized codegen for Triton backend.
"""

def __init__(self):
super().__init__()

def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
func = getattr(self, node.__class__.__name__)
assert func is not None, "unimplemented node: %s" % node.__class__.__name__
assert func is not None, f"unimplemented node: {node.__class__.__name__}"
func(node, context, code_buffer, indent)

def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -125,18 +122,29 @@ def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer,
def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
is_reduction = node.offset_calc.is_reduction
space_indent = " " * indent
autotune_configs_str = ""
for config in node.offset_calc.autotune_configs.configs:
if is_reduction:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
f"num_warps={config[2]}),\n"
)
else:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
)
keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'

if len(node.offset_calc.autotune_configs.configs) > 1:
autotune_configs_str = ""
for config in node.offset_calc.autotune_configs.configs:
if is_reduction:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
f"num_warps={config[2]}),\n"
)
else:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
)
keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'
code_buffer += (
f"{space_indent}@triton.autotune(\n"
f"{space_indent} configs=[\n"
f"{autotune_configs_str}"
f"{space_indent} ],\n"
f"{space_indent} key=[{keys_str}],\n"
f"{space_indent})\n"
)

input_args = [context.get_variable_name(input.name) for input in node.inputs]
input_args_str = ", ".join(input_args)
if input_args_str:
Expand All @@ -158,12 +166,6 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_
)

code_buffer += (
f"{space_indent}@triton.autotune(\n"
f"{space_indent} configs=[\n"
f"{autotune_configs_str}"
f"{space_indent} ],\n"
f"{space_indent} key=[{keys_str}],\n"
f"{space_indent})\n"
f"{space_indent}@triton.jit\n"
f"{space_indent}def {node.name}({input_args_str}{output_args_str}{other_input_args}{blocks_str}):\n"
)
Expand All @@ -175,8 +177,10 @@ def ElementwiseKernelNode( # noqa: N802
offset_calc = node.offset_calc
indent += 4
space_indent = " " * indent
x_numel_str = str(offset_calc.x_numel)
if x_numel_str.isnumeric():
code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
code_buffer += (
f"{space_indent}xnumel = {offset_calc.x_numel}\n"
f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)\n"
)
Expand Down Expand Up @@ -207,9 +211,13 @@ def ReduceKernelNode( # noqa: N802
offset_calc = node.offset_calc
indent += 4
space_indent = " " * indent
x_numel_str = str(offset_calc.x_numel)
if x_numel_str.isnumeric():
code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
r_numel_str = str(offset_calc.r_numel)
if r_numel_str.isnumeric():
code_buffer += f"{space_indent}rnumel = {r_numel_str}\n"
code_buffer += (
f"{space_indent}xnumel = {offset_calc.x_numel}\n"
f"{space_indent}rnumel = {offset_calc.r_numel}\n"
f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n"
f"{space_indent}rbase = tl.arange(0, RBLOCK)[None, :]\n"
Expand Down Expand Up @@ -444,6 +452,13 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
indent += 4
space_indent = " " * indent

seen_symbolic_shape = set()
for input in node.inputs:
for idx, dim in enumerate(input.shape):
if dim.is_symbol and dim not in seen_symbolic_shape:
code_buffer += f"{space_indent}{dim} = {context.get_variable_name(input.name)}.size()[{idx}]\n"
seen_symbolic_shape.add(dim)

if node.has_dropout:
code_buffer += (
f'{space_indent}seed_cuda = torch.randint(2**31, size=(), dtype=torch.int64, device="cuda")\n\n'
Expand All @@ -470,18 +485,31 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
if kernel_node.has_dropout:
kernel_args_str += ", seed_cuda"

# Support symbolic shape if any.
symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables)
if symbolic_shape_args_str:
kernel_args_str += f", {symbolic_shape_args_str}"

block_str = ""
if len(kernel_node.offset_calc.autotune_configs.configs) == 1:
config = kernel_node.offset_calc.autotune_configs.configs[0]
if kernel_node.offset_calc.is_reduction:
block_str = f", XBLOCK={config[0]}, RBLOCK={config[1]}, num_warps={config[2]}"
else:
block_str = f", XBLOCK={config[0]}, num_warps={config[2]}"

if isinstance(kernel_node, ReduceKernelNode):
code_buffer += (
f"{space_indent}x_numel = {kernel_node.offset_calc.x_numel}\n"
f"{space_indent}r_numel = {kernel_node.offset_calc.r_numel}\n"
f'{space_indent}grid = lambda meta: (triton.cdiv(x_numel, meta["XBLOCK"]),)\n'
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel)\n"
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel{block_str})\n"
)
else:
code_buffer += (
f"{space_indent}n_elements = {kernel_node.offset_calc.x_numel}\n"
f'{space_indent}grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)\n'
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements)\n"
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements{block_str})\n"
)

for name in node.cross_kernel_args_to_delete[idx]:
Expand Down
116 changes: 95 additions & 21 deletions orttraining/orttraining/python/training/ort_triton/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import sympy
from onnx import GraphProto, NodeProto, TensorProto

from ._sympy_utils import parse_shape
from ._sympy_utils import extract_shape_from_symbol
from ._utils import get_attribute, get_reduce_info, next_power_of_2

_SPECIAL_FLOATS: List[str] = ["inf", "-inf"]


class CodegenContext:
"""
Expand All @@ -28,7 +30,8 @@ def get_variable_name(self, name: str) -> str:
# For some operators such as data load/store, we need an internal variable name inside the kernel function.
def get_internal_variable_name(self, name: str) -> str:
var_name = self._var_map[name]
return self._var_map[var_name] if var_name in self._var_map else var_name
var_name = self._var_map[var_name] if var_name in self._var_map else var_name
return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name


class CodeBuffer:
Expand All @@ -49,14 +52,38 @@ def codegen(self, node: Any, context: CodegenContext, code_buffer: CodeBuffer, i
pass


class SymbolicDSU:
"""
A 'disjoint set union' to merge symbolics so that we use less variables in the generated code.
When handling shape inference for elementwise Ops, if two symbols are not equal and they are not 1, we merge them.
"""

def __init__(self):
self._dsu: Dict[sympy.Expr, sympy.Expr] = {}

def find(self, symbolic: sympy.Expr) -> sympy.Expr:
if symbolic not in self._dsu:
self._dsu[symbolic] = symbolic
return symbolic
if symbolic == self._dsu[symbolic]:
return symbolic
self._dsu[symbolic] = self.find(self._dsu[symbolic])
return self._dsu[symbolic]

def union(self, symbolic: sympy.Expr, other_symbolic: sympy.Expr):
root = self.find(symbolic)
other_root = self.find(other_symbolic)
self._dsu[other_root] = root


class TensorInfo:
"""
Represent a input/output tensor of a node.
"""

def __init__(self, dtype: TensorProto.DataType, shape: List[Any]):
def __init__(self, dtype: TensorProto.DataType, shape: List[sympy.Expr]):
self._dtype: TensorProto.DataType = dtype
self._shape: List[sympy.Expr] = parse_shape(shape)
self._shape: List[sympy.Expr] = shape

@property
def dtype(self) -> TensorProto.DataType:
Expand All @@ -66,27 +93,42 @@ def dtype(self) -> TensorProto.DataType:
def shape(self) -> List[sympy.Expr]:
return self._shape

def update_shape(self, symbolics: SymbolicDSU):
self._shape = [symbolics.find(dim) if dim.is_symbol else dim for dim in self._shape]


def _infer_elementwise_shape(input_infos: List[TensorInfo]) -> List[sympy.Expr]:
def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicDSU) -> List[sympy.Expr]:
max_len = max([len(input_info.shape) for input_info in input_infos])
output_shape: List[sympy.Expr] = [sympy.Integer(1)] * max_len
for input_info in input_infos:
offset = max_len - len(input_info.shape)
for i in range(len(input_info.shape)):
if not input_info.shape[i].is_number or input_info.shape[i] != 1:
output_shape[i + offset] = input_info.shape[i]
for idx, dim in enumerate(input_info.shape):
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
if not dim.is_number or dim != 1:
if not output_shape[idx + offset].is_number or output_shape[idx + offset] != 1:
symbolics.union(output_shape[idx + offset], dim)
else:
output_shape[idx + offset] = dim
return output_shape


def _infer_elementwise(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos))]
def _infer_elementwise(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos, symbolics))]


def _infer_where(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos))]
def _infer_where(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos, symbolics))]


def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_reduction(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
input_rank = len(input_infos[0].shape)
keep_dims, axes = get_reduce_info(node, graph, input_rank)
axes = [axis + input_rank if axis < 0 else axis for axis in axes]
Expand All @@ -98,17 +140,26 @@ def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: Grap
return [TensorInfo(input_infos[0].dtype, shape)]


def _infer_unary(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_unary(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [input_infos[0]]


def _infer_cast(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_cast(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
dtype = get_attribute(node, "to", TensorProto.UNDEFINED)
assert dtype != TensorProto.UNDEFINED
return [TensorInfo(dtype, input_infos[0].shape)]


def _infer_dropout(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_dropout(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [input_infos[0], TensorInfo(TensorProto.BOOL, input_infos[0].shape)]


Expand Down Expand Up @@ -138,10 +189,12 @@ class TypeAndShapeInfer:
}

@classmethod
def infer(cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def infer(
cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
if node.op_type not in cls._INFER_FUNC_MAP:
raise NotImplementedError(f"Unsupported op type: {node.op_type}")
return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph)
return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph, symbolics)


class AutotuneConfigs:
Expand All @@ -152,9 +205,30 @@ class AutotuneConfigs:
If it's reduction kernel on last contiguous dimensions, the contiguous flag is True.
"""

def __init__(self, x_numel: int, r_numel: int, contiguous: bool):
self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel, r_numel, contiguous)
self.requires_for_loop: bool = any(config[1] < r_numel for config in self.configs)
def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
x_numel_int = (
int(x_numel)
if x_numel.is_number
else int(
x_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols}
)
)
)
r_numel_int = (
int(r_numel)
if r_numel.is_number
else int(
r_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols}
)
)
)
self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous)
# If there is symbolic shape, we will not tune the kernel.
if not x_numel.is_number or not r_numel.is_number:
centwang marked this conversation as resolved.
Show resolved Hide resolved
self.configs = self.configs[-1:]
self.requires_for_loop: bool = any(config[1] < r_numel_int for config in self.configs)

def _num_warps(self, x: int, r: int) -> int:
return min(max(x * r // 256, 2), 8)
Expand Down
15 changes: 11 additions & 4 deletions orttraining/orttraining/python/training/ort_triton/_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _get_dtype_and_shape(self, arg_name: str, **kwargs):
arg_info = node_arg_infos[arg_name]
return arg_info.dtype, arg_info.shape

def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, **kwargs):
def _decompose_elementwise_precision(self, node: NodeProto, **kwargs):
x = node.input[0]
dtype, _ = self._get_dtype_and_shape(x, **kwargs)
if not _is_half_dtype(dtype):
Expand All @@ -79,15 +79,19 @@ def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, *
return [*cast_nodes, op_node, cast_node1]

def Exp(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def Pow(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def Sqrt(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
x = node.input[0]
w = node.input[1]
Expand Down Expand Up @@ -153,6 +157,7 @@ def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs): # n
]

def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
dy = node.input[0]
x = node.input[1]
Expand Down Expand Up @@ -241,6 +246,7 @@ def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs):
return decomposed_nodes

def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
x = node.input[0]
y = node.output[0]
Expand All @@ -259,6 +265,7 @@ def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return [max_node, sub_node, exp_node, sum_node, div_node]

def SoftmaxGrad_13(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
dy = node.input[0]
y = node.input[1]
Expand Down
Loading
Loading