---
name: tlx-api-reference
description: TLX DSL API reference for low-level GPU primitives. Use when writing or modifying TLX kernel code that uses barriers (mbarrier, named barriers), memory allocation (local_alloc, SMEM, TMEM), TMA operations, warp specialization (async_tasks, async_task), CLC (cluster launch control), or wgmma instructions. Covers Hopper and Blackwell hardware differences.
---

# TLX API Quick Reference
## Warp Specialization

| Function | Description | Arch |
|---|---|---|
| `tlx.async_tasks()` | Context manager wrapping all async task regions | Both |
| `tlx.async_task([task_ids])` | Assign code to specific task IDs (e.g., `[0]` = producer, `[1, 2]` = consumers) | Both |
| `tlx.async_task(num_warps=N, num_regs=R)` | Explicit warp/register allocation for a task | Both |
| `tlx.async_task("default", num_regs=R)` | Default task for code outside explicit tasks | Both |
| `tlx.async_task_replica_id()` | Returns the replica ID inside an async task region | Both |
### Warp specialization skeleton

```python
with tlx.async_tasks():
    with tlx.async_task([0]):
        ...  # producer: issue loads
    with tlx.async_task([1, 2]):
        ...  # consumers: compute
```
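A slightly fuller sketch, assuming task IDs can be combined with `num_regs` in a single `async_task` call; the register counts are illustrative, not recommended values:

```python
with tlx.async_tasks():
    with tlx.async_task([0], num_regs=40):         # producer: small register budget
        ...  # issue TMA loads
    with tlx.async_task([1, 2], num_regs=232):     # two consumer replicas
        cid = tlx.async_task_replica_id()          # 0 or 1: which consumer this is
        ...  # compute on this consumer's slice
    with tlx.async_task("default"):
        ...  # code outside explicit tasks
```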
## Memory Barriers

### mbarrier (shared-memory allocated)

| Function | Description | Arch |
|---|---|---|
| `tlx.alloc_barriers(num_barriers, arrive_count=1)` | Allocate SMEM barriers and initialize them with an arrive count | Both |
| `tlx.barrier_expect_bytes(bar, bytes, pred=None)` | Set the expected transaction byte count on a barrier | Both |
| `tlx.barrier_wait(bar, phase, pred=None)` | Wait until the barrier phase flips (LOCAL mbarrier only) | Both |
| `tlx.barrier_arrive(bar, arrive_count=1, remote_cta_rank=None)` | Signal arrival at a barrier. `remote_cta_rank` signals a barrier in a remote CTA; it is only valid when `ctas_per_cga > 1` and otherwise causes "Unexpected buffer remote view in 1cta mode". Guard with `if USE_2CTA:` when the kernel supports both modes (see the sketch after the arrive-count rules) | Both |
| `tlx.cluster_barrier()` | Full cluster-wide synchronization barrier | Both |
`arrive_count` rules:
- Implicit arrive from `barrier_expect_bytes`: use `arrive_count=1`
- `barrier_arrive` inside `tlx.async_task`: `arrive_count` = number of warp groups
- `barrier_arrive` outside `tlx.async_task`: `arrive_count=1` (only `tid == 0` arrives)
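A minimal sketch of the guarded remote arrive mentioned in the table above. `USE_2CTA`, `peer_rank`, and `stage` are hypothetical names for a kernel compiled in both 1-CTA and 2-CTA modes, and indexing a single barrier out of the array with `local_view` is assumed:

```python
bars_empty = tlx.alloc_barriers(NUM_STAGES, arrive_count=num_consumers)
bar_empty = tlx.local_view(bars_empty, stage)
if USE_2CTA:
    # signal the peer CTA's barrier (only legal when ctas_per_cga > 1)
    tlx.barrier_arrive(bar_empty, remote_cta_rank=peer_rank)
else:
    tlx.barrier_arrive(bar_empty)   # plain local arrive in 1-CTA mode
```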
### Named barriers (hardware-allocated, indices 0–15)

| Function | Description | Arch |
|---|---|---|
| `tlx.named_barrier_wait(bar_id, num_threads)` | Wait until `num_threads` threads arrive at `bar_id` | NVIDIA |
| `tlx.named_barrier_arrive(bar_id, num_threads)` | Signal arrival at `bar_id` | NVIDIA |
`num_threads` must be a multiple of 32 (the warp size); typically `num_warp_groups * warps_per_group * 32` (e.g., 2 warp groups of 4 warps each give 256, as in the PingPong pattern below).
Named barriers are used for PingPong scheduling to prevent tensor core contention between consumer warp groups.
## Memory Operations

### SMEM / TMEM allocation

| Function | Description | Arch |
|---|---|---|
| `tlx.local_alloc(shape, dtype, num, storage=smem, reuse=None, layout=None)` | Allocate a multi-buffered tensor in SMEM or TMEM | Both (TMEM: Blackwell) |
| `tlx.storage_alias_spec(storage=smem, buffer_size_bytes=None)` | Define a shared buffer region for multiple `local_alloc` calls via `reuse` | Both |
| `tlx.local_view(buf, index)` | Get a view of a single buffer from a multi-buffered tensor | Both |
| `tlx.local_slice(buf, start, end)` | Slice a sub-range of a buffered tensor | Both |
| `tlx.subslice(tensor, dim, start, size)` | Subslice a tensor along a dimension | Both |
| `tlx.local_load(buf)` | Load from an SMEM/TMEM buffer into registers | Both |
| `tlx.local_store(val, buf)` | Store from registers into an SMEM/TMEM buffer | Both |
| `tlx.local_trans(buf)` | Transpose a shared memory buffer | Both |
| `tlx.local_reinterpret(buf, dtype)` | Reinterpret a buffer with a different dtype | Both |
| `tlx.remote_view(buf, remote_cta_rank)` | Get a view of a buffer in a remote CTA's SMEM | Both |
| `tlx.remote_shmem_store(val, buf)` | Store to a remote CTA's shared memory | Both |
| `tlx.async_remote_shmem_store(val, buf)` | Async store to a remote CTA's shared memory | Both |
| `tlx.tmem_copy(src, dst)` | Copy between TMEM buffers | Blackwell |
| `tlx.fence_async_shared()` | Memory fence for async shared memory operations | Both |

Storage kinds: `tlx.storage_kind.smem`, `tlx.storage_kind.tmem` (Blackwell), `tlx.storage_kind.smemCluster`
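A minimal allocation sketch, assuming Triton dtypes (e.g., `tl.float16`) are accepted and that `BLOCK_M`, `BLOCK_K`, `NUM_STAGES`, and `stage` are defined by the surrounding kernel:

```python
a_tiles = tlx.local_alloc((BLOCK_M, BLOCK_K), tl.float16, NUM_STAGES,
                          storage=tlx.storage_kind.smem)   # NUM_STAGES-deep SMEM buffer
a_stage = tlx.local_view(a_tiles, stage)                   # one (BLOCK_M, BLOCK_K) buffer
a_regs = tlx.local_load(a_stage)                           # SMEM -> registers
tlx.local_store(a_regs, a_stage)                           # registers -> SMEM
```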
### TMA (Tensor Memory Accelerator)

| Function | Description | Arch |
|---|---|---|
| `tlx.make_tensor_descriptor(ptr, shape, strides, block_shape)` | Create a TMA descriptor from a pointer (host-side) | Hopper+ |
| `tlx.allocate_tensor_descriptor(ptr, shape, strides, block_shape, swizzle_mode)` | Allocate and fill a TMA descriptor in SMEM | Hopper+ |
| `tlx.reinterpret_tensor_descriptor(desc, dtype)` | Reinterpret a TMA descriptor with a different dtype | Hopper+ |
| `tlx.async_descriptor_load(desc, indices, barrier=None)` | Async TMA load from global → SMEM, tracked by a barrier | Hopper+ |
| `tlx.async_descriptor_store(desc, val, indices)` | Async TMA store from registers → global | Hopper+ |
| `tlx.async_descriptor_store_wait()` | Wait for all pending TMA stores to complete | Hopper+ |
| `tlx.async_load(ptr, buf, barrier)` | Async bulk copy global → SMEM (`cp.async`) | Hopper+ |
| `tlx.async_load_commit_group()` | Commit the current async load group | Hopper+ |
| `tlx.async_load_wait_group(n)` | Wait for async load groups (`n` pending allowed) | Hopper+ |
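One stage of a TMA pipeline as a hedged sketch; `a_desc`, `c_desc`, `bar_full`, the block sizes, and the offsets are assumed to come from the surrounding kernel:

```python
# Load: declare the expected byte count, then issue the TMA copy tracked by the barrier
nbytes = BLOCK_M * BLOCK_K * tlx.size_of(tlx.dtype_of(a_desc))
tlx.barrier_expect_bytes(bar_full, nbytes)
tlx.async_descriptor_load(a_desc, [off_m, off_k], barrier=bar_full)

# Store: issue the TMA write-back, then drain all pending TMA stores
tlx.async_descriptor_store(c_desc, acc, [off_m, off_n])
tlx.async_descriptor_store_wait()
```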
## Matrix Multiply (MMA)

| Function | Description | Arch |
|---|---|---|
| `tlx.async_dot(A, B, acc=None, use_acc=None, mBarriers=[], two_ctas=False)` | Warp-group MMA: D = A @ B + C. Maps to `wgmma` (Hopper) or `tcgen05.mma` (Blackwell) | Both |
| `tlx.async_dot_scaled(A, B, acc, A_scale, A_format, B_scale, B_format, ...)` | Scaled MMA with FP8 inputs: D = (A * scale_A) @ (B * scale_B) + D | Blackwell |
| `tlx.async_dot_wait(pendings, inp)` | Wait for outstanding async dot operations to complete (`pendings` = number allowed to remain) | Both |
| `tlx.tcgen05_commit(mBarrier, two_ctas=False)` | Make an mbarrier track completion of prior `tcgen05` ops. Use a SEPARATE mbarrier from `async_dot` | Blackwell |
Minimum tile sizes for `async_dot`: M ≥ 64, K ≥ 16, N ≥ 32.
Pair-CTA MMA (`two_ctas=True`): M must be 128 per CTA.
`input_precision` options: `tf32`, `tf32x3`, `ieee`.
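Two hedged sketches of completion tracking, one per architecture; the accumulator and barrier names are placeholders, and the exact wait semantics of `async_dot_wait` are noted as assumptions in the comments:

```python
# Hopper-style: issue the wgmma, then wait for outstanding async dots
tlx.async_dot(a_smem, b_smem, acc)
tlx.async_dot_wait(0, acc)              # assumed: blocks until no dots remain pending

# Blackwell-style: track tcgen05 MMA completion with a dedicated mbarrier
tlx.async_dot(a_smem, b_smem, acc_tmem)
tlx.tcgen05_commit(mma_done_bar)        # use a barrier separate from any passed to async_dot
tlx.barrier_wait(mma_done_bar, phase)
```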
## Multi-CTA (Cluster) Kernels

`ctas_per_cga=(N, 1, 1)` in `triton.Config` sets the cluster size. The grid specifies the total number of CTAs; hardware divides it by `ctas_per_cga` to get the number of clusters. E.g., `grid=(2, 1, 1)` with `ctas_per_cga=(2, 1, 1)` gives 1 cluster of 2 CTAs.
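A hedged launch-side sketch of the arithmetic above; the meta-parameters are hypothetical and `ctas_per_cga` is used as described in this section:

```python
import triton

# 2 CTAs per cluster; the grid below launches 2 CTAs total, i.e. exactly 1 cluster
config = triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8,
                       ctas_per_cga=(2, 1, 1))
grid = (2, 1, 1)   # total CTAs / ctas_per_cga[0] = 1 cluster
```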
## CLC (Cluster Launch Control), Blackwell only

| Function | Description |
|---|---|
| `tlx.clc_create_context(num_consumers, num_stages=1)` | Create a CLC pipeline context (allocates barriers + response buffers) |
| `tlx.clc_producer(context, p_producer, multi_ctas=False, k=0)` | Issue a CLC `try_cancel` request from CTA 0 |
| `tlx.clc_consumer(context, p_consumer, multi_ctas=False, k=0)` | Decode the tile ID from the CLC response and signal completion. Returns `tile_id` or -1 |

For 2-CTA mode, set `multi_ctas=True` (uses the "arrive remote, wait local" pattern).
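A hedged sketch of a persistent-kernel loop driven by CLC; `p_prod` and `p_cons` are assumed to be thread predicates selecting who issues and decodes the request, and the loop shape is illustrative rather than prescriptive:

```python
ctx = tlx.clc_create_context(num_consumers=NUM_CONSUMERS)
tile_id = tl.program_id(0)                 # first tile comes from the regular launch
while tile_id >= 0:
    ...  # compute the tile identified by tile_id
    tlx.clc_producer(ctx, p_prod)              # CTA 0 asks the scheduler for more work
    tile_id = tlx.clc_consumer(ctx, p_cons)    # -1 once the grid is exhausted
```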
## Utility

| Function | Description | Arch |
|---|---|---|
| `tlx.cluster_cta_rank()` | Unique CTA ID within a cluster (across all dims) | Both |
| `tlx.thread_id(axis)` | Thread ID along axis 0, 1, or 2 | Both |
| `tlx.dtype_of(tensor_or_desc)` | Element type of a tensor or tensor descriptor | Both |
| `tlx.size_of(dtype)` | Size of a dtype in bytes | Both |
| `tlx.get_fp8_format_name(dtype)` | FP8 format string ("e5m2" or "e4m3") for scaled MMA | Both |
| `tlx.clock64()` | 64-bit hardware clock value (for timing) | Both |
| `tlx.stoch_round(src, dst_ty, rand_bits)` | Hardware stochastic rounding FP32 → FP8/BF16/F16 | Blackwell |
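A short hedged sketch tying the FP8 utilities to `async_dot_scaled`; the operand and scale names are placeholders:

```python
a_fmt = tlx.get_fp8_format_name(tlx.dtype_of(a_fp8))   # e.g. "e4m3"
b_fmt = tlx.get_fp8_format_name(tlx.dtype_of(b_fp8))   # e.g. "e5m2"
tlx.async_dot_scaled(a_fp8, b_fp8, acc, a_scale, a_fmt, b_scale, b_fmt)
```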
## Common patterns

### Producer-consumer with mbarrier (pipelined GEMM)

```python
# One "full" and one "empty" barrier per pipeline stage
bars_full = tlx.alloc_barriers(num_stages, arrive_count=1)
bars_empty = tlx.alloc_barriers(num_stages, arrive_count=num_consumers)

# Producer: declare expected bytes, then issue the TMA load tracked by the barrier
tlx.barrier_expect_bytes(bar_full, nbytes)
tlx.async_descriptor_load(desc, indices, barrier=bar_full)

# Consumer: wait for the data, compute, then release the stage
tlx.barrier_wait(bar_full, phase)
tlx.async_dot(A, B, acc)
tlx.barrier_arrive(bar_empty)
```
### PingPong with named barriers

```python
# Consumer warp group 0: compute first, then hand the tensor cores to group 1
tlx.named_barrier_wait(9, 256)
qk = tlx.async_dot(q, k)
tlx.named_barrier_arrive(10, 256)

# Consumer warp group 1: let group 0 go first, then take its turn
tlx.named_barrier_arrive(9, 256)
tlx.named_barrier_wait(10, 256)
qk = tlx.async_dot(q, k)
```
## Deep-dive docs

- API reference: `third_party/tlx/README.md`
- Barriers: `third_party/tlx/doc/tlx_barriers.md`
- Placeholder layouts: `third_party/tlx/doc/PlaceholderLayouts.md`
- Storage alias design: `third_party/tlx/doc/storage_alias_spec_design.md`