Use TLE-Raw#
This section introduces how to use TLE-Raw. TLE-Raw is available on trition_3.6.x branch.
MLIR#
The following is an example of MLIR (Multi-Level Intermediate Representation).
from typing import Annotated
from mlir import ir
from mlir.dialects import arith, nvvm, tensor
import triton.language as tl
from triton.experimental.flagtree.edsl import dialect
import triton.experimental.flagtree.language as fl
# 1. Dialect declaration
@tle.raw.language(name="mlir")
# 2. Hardware constraint
@tle.hardware_constraint(threads_dim=1, sync_scope="block")
# 3. Function implementation
def vector_add_tile(
x: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
y: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
output: Annotated[ir.RankedTensorType, "tensor<1024xf32>"]
):
# Write low-level operations directly using the MLIR Python bindings
tidx = nvvm.ThreadIdXOp(ir.IntegerType.get_signless(32)).res
bidx = nvvm.BlockIdXOp(ir.IntegerType.get_signless(32)).res
bdimx = nvvm.BlockDimXOp(ir.IntegerType.get_signless(32)).res
idx = arith.addi(arith.muli(bidx, bdimx), tidx)
idx = arith.index_cast(ir.IndexType.get(), idx)
xval = tensor.extract(x, [idx])
yval = tensor.extract(y, [idx])
result = arith.addf(xval, yval)
tensor.insert(result, output, [idx])
@tle.jit
def add_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
# Tile language main code
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = tl.zeros_like(x)
# 4. Function call
tle.call(
vector_add_tile,
args=[x, y, output],
hardware={
"threads": (BLOCK_SIZE,), # Must satisfies threads_dim=1
},
layout={
x: {"space": "shared", "order": [0]}, # Shared memory, one-dimensional layout (for optimizing connection)
y: {"space": "shared", "order": [0]},
output: {"space": "shared", "order": [0]}
}
)
tl.store(output_ptr + offsets, output, mask=mask)
TLE-raw consists of the following four parts:
Dialect declaration (decorator)
Decorator:
@tle.raw.language(name="mlir")Explanation: This decorator marks the function
vector_add_tileas a block of code written directly in the MLIR dialect. It tells the compiler, specifically through the FlagTree EDSL (Embedded Domain Specific Language), that the body of this function should be interpreted and lowered using MLIR operations (such asnvvm,arith, andtensor), rather than standard Python or Triton operations.
Hardware constraint (decorator)
Decorator:
@tle.hardware_constraint(threads_dim=1, sync_scope="block")Explanation: This decorator imposes constraints on the hardware execution model for the
vector_add_tilefunction. It specifies that the function operates in a 1-dimensional thread space (threads_dim=1) and that synchronization primitives should be scoped at the block level (sync_scope="block").
Function implementation
Function:
vector_add_tile(...)Explanation: This is the actual implementation of the computation kernel written using low-level MLIR Python bindings. It defines the specific operations (thread indexing, memory loading, floating-point addition, and memory storing) that will be executed by the hardware. The function signature uses Annotated types to explicitly define the input and output as
tensor<1024xf32>(1024-element float32 tensors), ensuring the compiler knows the exact data layout and types to expect.
Function call
Invocation:
tle.call(vector_add_tile, args=[x, y, output], hardware={...}, layout={...})Explanation: This line invokes the declared MLIR function (
vector_add_tile) from within the high-level Triton kernel (add_kernel). It passes the input tensorsx,y, and the output buffer. Crucially, it provides hardware mapping hints (defining the number of threads) and memory layout specifications (defining the tensors as residing in “shared” memory with a specific order). This allows the compiler to bridge the gap between the high-leveltl.load/tl.storeoperations and the low-level MLIR IR generation.