Use TLE-Lite#
This section describes how to use TLE-Lite. TLE-Lite is available on the triton_3.6.x branch.
Memory management#
Use the following operations to manage memory.
tle.load#
tle.load loads a tensor from global memory (GMEM). Pass is_async=True to hint that the load should be issued asynchronously.
x = tle.load(..., is_async=True)
Tensor slicing#
The following operations split a tensor into a grid of sub-tiles of a given shape, then read or update the sub-tile at the given grid coordinates. On GPU, sub-tiles can reside in registers or shared memory.
tle.extract_tile#
Splits the input tensor into a grid of sub-tiles based on the specified sub-tile shape, and extracts the sub-tile at the given coordinates.
Supports extraction into registers and shared memory.
# x is [4, 4]
# z is [2, 2]
# Split x into a sub-tile grid with shape=[2, 2], and extract the sub-tile at [0, 0]
z = x.extract_tile(index=[0, 0], shape=[2, 2])
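The grid-coordinate semantics described above can be sketched with a host-side reference model. This is a plain-Python illustration, not TLE code: `extract_tile_ref` is a hypothetical helper, and it assumes that `index` counts tiles rather than elements and that the tensor divides evenly by the tile shape.

```python
def extract_tile_ref(x, index, shape):
    """Reference model: return the sub-tile of 2D list `x` at grid coordinate `index`."""
    tile_rows, tile_cols = shape
    rows, cols = len(x), len(x[0])
    # Assumption: the tensor must divide evenly into the tile grid.
    assert rows % tile_rows == 0 and cols % tile_cols == 0
    # Grid coordinates are scaled by the tile shape to get element offsets.
    r0, c0 = index[0] * tile_rows, index[1] * tile_cols
    return [row[c0:c0 + tile_cols] for row in x[r0:r0 + tile_rows]]


x = [[4 * r + c for c in range(4)] for r in range(4)]  # 4x4, values 0..15
z = extract_tile_ref(x, index=[0, 1], shape=[2, 2])
print(z)  # [[2, 3], [6, 7]]
```

With a [2, 2] tile shape, a [4, 4] tensor yields a 2x2 grid, so valid indices run from [0, 0] to [1, 1].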
tle.insert_tile#
Splits the input tensor into a grid of sub-tiles based on the sub-tile shape, and updates the sub-tile at the specified coordinates with a new tile.
Supports updates from registers and shared memory.
# x is [4, 4], y is [2, 2], z is [4, 4]
# Split x into sub-tiles of shape=[2, 2], update the [0, 0] sub-tile with y, and return the full [4, 4] tensor
z = x.insert_tile(y, index=[0, 0])
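As a rough illustration of the update semantics, the following plain-Python reference model writes a tile into a copy of the input. `insert_tile_ref` is a hypothetical helper, not part of TLE; it assumes the tile shape is taken from `y`, that `index` counts tiles rather than elements, and that the result is a new tensor.

```python
def insert_tile_ref(x, y, index):
    """Reference model: copy 2D list `x` and write tile `y` at grid coordinate `index`."""
    tile_rows, tile_cols = len(y), len(y[0])
    # Assumption: the tensor must divide evenly into the tile grid.
    assert len(x) % tile_rows == 0 and len(x[0]) % tile_cols == 0
    r0, c0 = index[0] * tile_rows, index[1] * tile_cols
    z = [row[:] for row in x]  # copy so the input is left unchanged
    for dr in range(tile_rows):
        z[r0 + dr][c0:c0 + tile_cols] = y[dr]
    return z


x = [[0] * 4 for _ in range(4)]
y = [[1, 2], [3, 4]]
z = insert_tile_ref(x, y, index=[1, 0])
for row in z:
    print(row)
```

Here the tile lands in the lower-left quadrant of the result, while `x` itself stays all zeros.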
Distribution#
The Triton distributed API consists of four core parts: device mesh definition, sharding specification, synchronization, and remote access (point-to-point communication).
Device mesh#
tle.device_mesh#
tle.device_mesh defines the topological structure of physical devices. It is the fundamental context for all distributed operations.
from math import prod

class device_mesh:
    def __init__(self, topology: dict):
        """
        Initialize a DeviceMesh.

        Args:
            topology (dict): A dictionary describing the hardware hierarchy.
                Keys are level names; values are either an integer (for 1D)
                or a list of tuples (for multi-dimensional levels).
        """
        self._physical_ids = ...  # Internal storage: flattened list of physical IDs (0..N-1)
        self._shape = ...         # Shape of the current logical view, e.g., (2, 2, 4, 2, 2, 4)
        self._dim_names = ...     # Names of the current dimensions
        # Initialization and parsing logic...

    @property
    def shape(self):
        """Return the logical shape of the current mesh."""
        return self._shape

    @property
    def ndim(self):
        """Return the number of dimensions."""
        return len(self._shape)

    def flatten(self):
        """
        Flatten the mesh into 1D. Commonly used for ring-based communication patterns.
        """
        # reshape is not shown in this outline.
        return self.reshape(prod(self._shape))

    def __getitem__(self, key):
        """
        Support slicing operations and return a sub-mesh.
        Supports standard slices (slice objects) and integer indexing.
        """
        # Compute new shape and selected physical IDs after slicing
        # ...
        return sub_mesh

    def __repr__(self):
        return f"DeviceMesh(shape={self._shape}, names={self._dim_names})"
# Define a complex hardware hierarchy
topology = {
    # Inter-node level (2x2 = 4 nodes)
    "node": [("node_x", 2), ("node_y", 2)],
    # Intra-node GPUs (4 devices)
    "device": 4,
    # Intra-GPU clusters (2x2)
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    # Blocks within each cluster (4 blocks)
    "block": 4,
}
# mesh.shape -> (2, 2, 4, 2, 2, 4)
# Total size = 256
mesh = tle.device_mesh(topology=topology)
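To show how the topology dictionary maps to the logical shape quoted in the comments, here is a minimal host-side sketch. `flatten_topology` and `flat_id` are hypothetical helpers, not part of TLE, and the row-major numbering of physical IDs is an assumption rather than documented behavior.

```python
from math import prod


def flatten_topology(topology):
    """Return (dim_names, shape) for a topology dict like the one above."""
    names, shape = [], []
    for level, spec in topology.items():
        if isinstance(spec, int):
            # 1D level: the level name doubles as the dimension name.
            names.append(level)
            shape.append(spec)
        else:
            # Multi-dimensional level: a list of (name, size) tuples.
            for name, size in spec:
                names.append(name)
                shape.append(size)
    return tuple(names), tuple(shape)


def flat_id(coords, shape):
    """Row-major physical ID of a mesh coordinate (the ordering is an assumption)."""
    idx = 0
    for c, n in zip(coords, shape):
        assert 0 <= c < n
        idx = idx * n + c
    return idx


topology = {
    "node": [("node_x", 2), ("node_y", 2)],
    "device": 4,
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    "block": 4,
}
names, shape = flatten_topology(topology)
print(shape, prod(shape))  # (2, 2, 4, 2, 2, 4) 256
print(flat_id((1, 1, 3, 1, 1, 3), shape))  # 255
```

The sketch relies on Python dictionaries preserving insertion order, so the levels appear in the shape in the order they are written.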
Synchronization#
In complex distributed operators—such as Ring-AllReduce or pipelined execution with independent row/column communication—we often need to synchronize only thread blocks within the same “row” or “column,” rather than across the entire cluster. A global synchronization would introduce unnecessary waiting overhead. This API supports sub-mesh synchronization, meaning that within a large physical cluster, we can define multiple logical “communication groups” and perform synchronization independently within each group.
def distributed_barrier(mesh):
    """
    If a sub-mesh is passed, only devices within that sub-mesh are synchronized.
    Devices outside the sub-mesh should treat this instruction as a no-op
    (or the compiler should ensure their control flow never reaches this point).
    """
    pass
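The idea of independent communication groups can be illustrated with a host-side analogy using Python threads. This is not TLE code: the 2x4 mesh, the per-row `threading.Barrier` objects, and the `worker` function are all assumptions made for the sketch, with each row standing in for one sub-mesh.

```python
import threading

ROWS, COLS = 2, 4  # hypothetical 2D mesh: 2 rows of 4 devices

# One barrier per row plays the role of a barrier on a row sub-mesh.
row_barriers = [threading.Barrier(COLS) for _ in range(ROWS)]
arrived = [[] for _ in range(ROWS)]
lock = threading.Lock()


def worker(row, col):
    with lock:
        arrived[row].append(col)
    # Blocks until all COLS workers of this row arrive; other rows are not involved.
    row_barriers[row].wait()
    with lock:
        # Past the barrier, every member of this row has already arrived.
        assert len(arrived[row]) == COLS


threads = [threading.Thread(target=worker, args=(r, c))
           for r in range(ROWS) for c in range(COLS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print([sorted(cols) for cols in arrived])  # [[0, 1, 2, 3], [0, 1, 2, 3]]
```

Because each row waits only on its own barrier, a slow worker in one row never delays the other row, which is the overhead a global barrier would add.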
tle.distributed_barrier#
tle.distributed_barrier synchronizes only the set of devices corresponding to the given mesh or sub-mesh.
Read from neighboring shards (ring-style exchange).
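# Assumes a 2D mesh of shape (num_nodes, num_devices), so mesh.shape[1] is the device count
node_rank = tle.shard_id(mesh, "node")
device_rank = tle.shard_id(mesh, "device")
next_device = (device_rank + 1) % mesh.shape[1]
remote_x = tle.remote(x, shard_id=(node_rank, next_device), scope=mesh)
# Synchronize the mesh before reading the neighbor's data
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(remote_x)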
Remote access#
tle.remote is used to obtain a handle to a tensor located on another device. This corresponds to point-to-point communication or direct memory access (e.g., RDMA/NVLink Load). It enables kernels to explicitly access data from a specific shard.
def remote(tensor, shard_id, scope):
    """
    Obtains a handle to a Remote Tensor residing on a specific device shard.

    :param tensor: A logically distributed tensor (already annotated with tle.sharding).
    :param shard_id: tuple. The coordinates of the target device within the Device Mesh.
        For example, if mesh=(2,4) and shard_id=(0, 3), this refers to GPU #3 on node #0.
    :param scope: The device mesh (or sub-mesh) in which shard_id is interpreted, e.g. scope=mesh.
    :return: RemoteTensor. Supports operations such as load, store, etc.
    """
tle.remote: Explicitly read from or write to remote shards.
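# Assumes a 2D mesh of shape (num_nodes, num_devices), so mesh.shape[1] is the device count
node_rank = tle.shard_id(mesh, "node")
device_rank = tle.shard_id(mesh, "device")
next_device = (device_rank + 1) % mesh.shape[1]
remote_x = tle.remote(x, shard_id=(node_rank, next_device), scope=mesh)
# Synchronize the mesh before reading the neighbor's data
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(remote_x)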
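The ring-style addressing in this example can be checked on the host with a small simulation. This is plain Python rather than TLE code: the `shards` dictionary and the `ring_neighbor` helper are hypothetical, and the mesh is assumed to have shape (2, 4) as in the docstring example.

```python
NUM_NODES, NUM_DEVICES = 2, 4  # mesh shape (2, 4), as in the docstring example

# Each shard's local data, keyed by (node_rank, device_rank).
shards = {(n, d): f"x[{n},{d}]"
          for n in range(NUM_NODES) for d in range(NUM_DEVICES)}


def ring_neighbor(node_rank, device_rank, num_devices):
    """Coordinate of the next device on the same node, wrapping around."""
    return (node_rank, (device_rank + 1) % num_devices)


# What each shard would read through its remote handle.
neighbor_vals = {
    coord: shards[ring_neighbor(*coord, NUM_DEVICES)] for coord in shards
}
print(neighbor_vals[(0, 3)])  # x[0,0]
print(neighbor_vals[(1, 1)])  # x[1,2]
```

The modulo keeps the exchange within one node: the last device wraps around to device 0 of the same node instead of crossing to the next node.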
Primitives that interact with local_ptr#
The following APIs are used together with tle.gpu.local_ptr. For more information, see Use TLE-Struct.
tl.load (for local_ptr)
tl.store (for local_ptr)
tl.atomic_add/and/cas/max/min/or/xchg/xor (for local_ptr)