mpk-internals
// Reference guide for the MPK compilation-to-runtime pipeline. Use when asked how MPK works internally, how compilation/code generation works, what happens at runtime, or when debugging the megakernel scheduler.
| name | description |
|---|---|
| mpk-internals | Reference guide for the MPK compilation-to-runtime pipeline. Use when asked how MPK works internally, how compilation/code generation works, what happens at runtime, or when debugging the megakernel scheduler. |
This document traces the full lifecycle of an MPK megakernel from Python graph construction through CUDA compilation to persistent kernel execution.

    Phase 1: Python Graph Building
      PersistentKernel.compile()
        → layer methods build KNGraph/TBGraph
        → kn_graph.generate_task_graph()
                 |
                 v
    Phase 2: C++ Code Generation (runtime.cc)
      Graph::generate_task_graph()
        → register_mugraph(): builds task/event lists
        → print_task_graph(): emits CUDA code + JSON
                 |
                 v
    Two artifacts:
      test.cu          : _init_persistent_kernel(), _execute_task(), Python C ext
      task_graph.json  : task descriptors, events, dependencies
                 |
                 v
    Phase 3: CUDA Compilation
      nvcc test.cu → test.so (Python extension module: __mirage_launcher)
                 |
                 v
    Phase 4: Runtime Initialization
      init_persistent_kernel()
        → loads JSON, allocates GPU queues, builds RuntimeConfig
                 |
                 v
    Phase 5: Runtime Execution
      launch_persistent_kernel()
        → prepare_kernel (reset queues)
        → worker_kernel + scheduler_kernel (persistent loop)
        → workers fetch tasks, wait on events, call _execute_task()
        → schedulers process events, enqueue tasks to workers

Phase 1: Python Graph Building
File: python/mirage/mpk/persistent_kernel.py
Entry point: PersistentKernel.compile()
The compilation method does the following in order:
1. Generate task graph: calls self.kn_graph.generate_task_graph(num_gpus, my_gpu_id), which bridges through Cython (python/mirage/_cython/core.pyx, generate_task_graph()) into C++ and returns {"cuda_code": str, "json_file": str}.
2. Write files: writes test.cu (CUDA code plus the HARD_CODE Python extension wrapper) and task_graph.json to a temp directory.
3. Compile: builds the nvcc command via get_compile_command() and runs it with subprocess.check_call().
4. Load module: uses importlib.util.spec_from_file_location() to dynamically load the compiled .so as the Python module __mirage_launcher, then extracts init_func, launch_func, init_request_func, and finalize_func.
5. Initialize runtime: calls init_func(...) with meta-tensor pointers, worker/scheduler counts, and the serving config.
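
The Load module step uses the standard importlib pattern for loading a module from an explicit file path. The sketch below is a minimal illustration of that pattern only: a plain .py stub (hypothetical, not MPK code) stands in for the compiled test.so so the example runs without nvcc. For a real extension module, spec_from_file_location picks the extension loader from the .so suffix, and the module name passed in must match the name the extension was built with (its PyInit_ function).

```python
import importlib.util
import pathlib
import sys
import tempfile

# Hypothetical stand-in for the compiled test.so: a plain Python module that
# exposes the same four entry-point names as the HARD_CODE wrapper.
STUB_SOURCE = '''
def init_func(*args):
    return "init"

def launch_func(stream):
    return "launch"

def init_request_func():
    return "init_request"

def finalize_func():
    return "finalize"
'''


def load_module_from_path(path, module_name="__mirage_launcher"):
    """Load a module from an explicit file path and register it in sys.modules."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # runs the module body
    return module


stub_path = pathlib.Path(tempfile.mkdtemp()) / "launcher_stub.py"
stub_path.write_text(STUB_SOURCE)

launcher = load_module_from_path(str(stub_path))
# Extract the entry points, as compile() does after loading the real module.
init_func = launcher.init_func
launch_func = launcher.launch_func
init_request_func = launcher.init_request_func
finalize_func = launcher.finalize_func
```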
Each layer method (e.g., rmsnorm_layer, linear_layer, moe_w13_fp8_layer) does the following:
1. Creates a TBGraph with CyTBGraph(grid_dim, block_dim, forloop_range, reduction_dimx).
2. Calls tb_graph.new_input(dtensor, input_map, forloop_dim, store_in_dmem) for each input and output.
3. Calls self.kn_graph.customized([tensors...], tb_graph) to register the operator.
4. Calls self.kn_graph.register_task(tb_graph, "task_name"), which dispatches to the C++ Graph::register_task().
The HARD_CODE constant (top of persistent_kernel.py) is a C string appended to the generated CUDA code. It defines a Python extension module with four functions:
- init_func: parses the Python arguments and calls the C++ init_persistent_kernel().
- launch_func: takes a CUDA stream pointer and calls launch_persistent_kernel(stream).
- init_request_func: calls init_request_resources() (for online serving).
- finalize_func: calls finalize_persistent_kernel().
Each layer method (e.g., rmsnorm_layer, linear_layer, moe_w13_linear_layer) builds a TBGraph that describes how the global tensors are sliced into per-task tiles. This section explains every parameter.
CyTBGraph constructor

    tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, forloop_range, reduction_dimx))

| Parameter | Meaning |
|---|---|
| grid_dim | (x, y, z): number of task instances along each grid axis. Total tasks = x * y * z. |
| block_dim | (threads, 1, 1): threads per task. Must be (128, 1, 1) on Ampere and (256, 1, 1) on Hopper/Blackwell. |
| forloop_range | Number of forloop iterations (always 1 in MPK; see the forloop_dim note below). |
| reduction_dimx | Tile size for the reduction dimension (always 64 in MPK). |
tb_graph.new_input(): registering a tensor

    tb_graph.new_input(dtensor, input_map, forloop_dim, store_in_dmem)

Called for every tensor the task touches — both inputs and outputs. The first num_inputs calls register inputs; the remaining register outputs. This ordering must match num_inputs/num_outputs in graph.cc's task_config tuple.
input_map: the partition tuple
A 3-element tuple (mx, my, mz) that maps grid axes to tensor dimensions:
| input_map.x value | Meaning |
|---|---|
| -1 | grid_dim.x does not partition this tensor: tasks that differ only in their x coordinate see the same data. If all three entries are -1, every task sees the full tensor. |
| 0 | grid_dim.x partitions tensor dimension 0. The task at grid position gx sees the slice [gx * dim[0]/grid_x : (gx+1) * dim[0]/grid_x] along dim 0. |
| 1 | grid_dim.x partitions tensor dimension 1, with the same slicing logic on dim 1. |
| 2 | grid_dim.x partitions tensor dimension 2. |
input_map.y and input_map.z work identically for grid_dim.y and grid_dim.z.
In short: the value tells you which tensor dimension that grid axis splits. -1 means "don't split by this grid axis."
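
The slicing rule in the table can be written as a small function. This is a sketch for reasoning about input_map values, not the code in MPK's graph.cc; the helper name task_slices and its divisibility check are assumptions, and it assumes each tensor dimension is split by at most one grid axis.

```python
from typing import List, Sequence, Tuple

Grid = Tuple[int, int, int]


def task_slices(dims: Sequence[int], grid_dim: Grid, input_map: Grid,
                coord: Grid) -> List[Tuple[int, int]]:
    """Return the [start, stop) range along each tensor dimension seen by the
    task at grid position coord = (gx, gy, gz).

    input_map[axis] == d  -> grid axis `axis` splits tensor dimension d
    input_map[axis] == -1 -> grid axis `axis` does not split this tensor
    """
    ranges = [(0, extent) for extent in dims]  # default: full extent
    for axis in range(3):
        d = input_map[axis]
        if d == -1:
            continue
        parts = grid_dim[axis]
        if dims[d] % parts != 0:  # assumption: an even split is required
            raise ValueError(f"dim {d} ({dims[d]}) not divisible by grid axis {axis} ({parts})")
        tile = dims[d] // parts
        ranges[d] = (coord[axis] * tile, (coord[axis] + 1) * tile)
    return ranges


# A weight shaped (num_experts, 2*intermediate_size, hidden) with hypothetical
# sizes, split along dim 1 by grid_dim.y = 6; task at (gx, gy, gz) = (0, 2, 0).
print(task_slices((16, 1536, 2048), (1, 6, 1), (-1, 1, -1), (0, 2, 0)))
# [(0, 16), (512, 768), (0, 2048)]
```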
forloop_dim (vestigial in MPK)
In the Mirage superoptimizer, forloop_dim and forloop_range together control tiled reduction loops within a TBGraph. However, in MPK forloop_range is always 1, which makes forloop_dim a no-op: the dimension division (dim / 1) and stride multiplier (* 1) have no effect regardless of the value you pass.
MPK task kernels handle their own internal tiling and reduction directly in CUDA (e.g., looping over the K dimension in a matmul). The TBGraph forloop mechanism is not used. You'll see various forloop_dim values in existing layer methods (e.g., 1, 2, -1), but they're all equivalent when forloop_range=1. By convention, existing code sets forloop_dim to the "reduction dimension" of the operation, but this is cosmetic.
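
A toy model of the forloop shape arithmetic (dividing the forloop dimension by forloop_range) shows why the value of forloop_dim stops mattering once forloop_range is 1. The function per_iteration_shape is hypothetical and models only the division described above, not Mirage's actual implementation.

```python
def per_iteration_shape(dims, forloop_dim, forloop_range):
    """Toy model: the forloop dimension is divided into forloop_range steps.
    forloop_dim == -1 means no dimension is iterated over."""
    shape = list(dims)
    if forloop_dim != -1:
        if shape[forloop_dim] % forloop_range != 0:
            raise ValueError("forloop_range must divide the forloop dimension")
        shape[forloop_dim] //= forloop_range
    return tuple(shape)


task_tile = (8, 256, 2048)

# MPK: forloop_range is always 1, so every forloop_dim yields the same tile.
shapes_in_mpk = {fd: per_iteration_shape(task_tile, fd, 1) for fd in (-1, 0, 1, 2)}

# Superoptimizer-style tiling (forloop_range > 1): now forloop_dim matters.
shapes_tiled = {fd: per_iteration_shape(task_tile, fd, 4) for fd in (-1, 2)}
```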
store_in_dmem
True means the per-task tensor slice lives in device (global) memory. It should be set to True for all MPK tensors.
Example: moe_w13_linear_layer

    def moe_w13_linear_layer(self, input, weight, moe_routing_indices,
                             moe_mask, output, grid_dim, block_dim):
        # input:               (batch_size, hidden_size)                         2D bf16
        # weight:              (num_experts, 2*intermediate_size, hidden_size)   3D bf16
        # moe_routing_indices: (num_experts, batch_size)                         2D int32
        # moe_mask:            (num_experts + 1,)                                1D int32
        # output:              (batch_size, num_experts_per_tok, 2*inter_size)   3D bf16
        tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64))
        # Arguments: tensor, input_map, forloop_dim*, store_in_dmem
        # (* forloop_dim is vestigial in MPK: it has no effect when forloop_range=1)
        tb_graph.new_input(input, (-1, -1, -1), 1, True)
        # → No partition on any grid axis. Every task sees the full (batch, hidden).
        tb_graph.new_input(weight, (-1, 1, -1), 2, True)
        # → grid_dim.y partitions dim 1 (the 2*intermediate_size axis).
        #   Each task handles 2*inter_size / grid_dim.y rows of the weight matrix.
        tb_graph.new_input(moe_routing_indices, (-1, -1, -1), -1, True)
        # → No partition. Every task sees the full routing table.
        tb_graph.new_input(moe_mask, (-1, -1, -1), -1, True)
        # → No partition. Every task sees the full mask.
        tb_graph.new_input(output, (-1, 2, -1), -1, True)
        # → grid_dim.y (the middle tuple entry) partitions dim 2, the 2*intermediate_size
        #   axis of the output, so each task writes the output columns matching its weight rows.
        self.kn_graph.customized([input, weight, moe_routing_indices, moe_mask, output], tb_graph)
        self.kn_graph.register_task(tb_graph, "moe_w13_linear_sm100")

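To see what each task of moe_w13_linear_layer receives, the sketch below applies the input_map rule to every tensor in the example. The model sizes and grid_dim = (1, 6, 1) are hypothetical values chosen for illustration. Because the middle entry of each tuple is the y entry, weight (-1, 1, -1) and output (-1, 2, -1) are both split by grid_dim.y, so task gy reads weight rows and writes output columns from the same [gy * 256, (gy + 1) * 256) range.

```python
from math import prod


def per_task_shape(dims, grid_dim, input_map):
    """Shape of one task's tile: each mapped tensor dim is divided by its grid axis."""
    shape = list(dims)
    for axis, d in enumerate(input_map):
        if d != -1:
            shape[d] //= grid_dim[axis]
    return tuple(shape)


# Hypothetical model sizes, for illustration only.
batch, hidden, experts, inter, topk = 8, 2048, 16, 768, 4
grid_dim = (1, 6, 1)

tensors = {
    "input": ((batch, hidden), (-1, -1, -1)),
    "weight": ((experts, 2 * inter, hidden), (-1, 1, -1)),
    "moe_routing_indices": ((experts, batch), (-1, -1, -1)),
    "moe_mask": ((experts + 1,), (-1, -1, -1)),
    "output": ((batch, topk, 2 * inter), (-1, 2, -1)),
}

num_tasks = prod(grid_dim)  # 6 task instances
tiles = {name: per_task_shape(dims, grid_dim, m) for name, (dims, m) in tensors.items()}
# tiles["weight"] == (16, 256, 2048), tiles["output"] == (8, 4, 256)
```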
The partition tuple is resolved during task graph generation (src/threadblock/graph.cc), not while the kernel runs. For each task instance (one grid coordinate), the code generator computes a byte offset from the tensor's base pointer:

    per_task_ptr = base_ptr
                 + blockIdx.x * stride_for(input_map.x)
                 + blockIdx.y * stride_for(input_map.y)
                 + blockIdx.z * stride_for(input_map.z)

Here blockIdx denotes the task's grid coordinate, and an input_map entry of -1 contributes no offset. These offsets are baked into the TaskDesc at init time (via JSON → FullTaskDesc → TaskDesc). The task kernel receives pre-offset pointers in task_desc->input_ptrs[i] and task_desc->output_ptrs[i], which is why tasks are blockIdx-agnostic.
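
The offset formula can be made concrete with a short sketch. It assumes row-major strides measured in elements (as in the task_graph.json "strides" field) and converts to bytes with a hypothetical per-dtype size table; stride_for(-1) is taken to be 0, so unmapped grid axes add nothing. It illustrates the arithmetic only and is not the code in src/threadblock/graph.cc.

```python
DTYPE_BYTES = {"bf16": 2, "fp16": 2, "fp32": 4, "int32": 4}  # hypothetical table


def row_major_strides(dims):
    """Element strides for a contiguous row-major tensor."""
    strides, acc = [], 1
    for extent in reversed(dims):
        strides.append(acc)
        acc *= extent
    return tuple(reversed(strides))


def per_task_byte_offset(dims, strides, dtype, grid_dim, input_map, coord):
    """Byte offset of one task's slice from the tensor base pointer."""
    offset_elems = 0
    for axis in range(3):
        d = input_map[axis]
        if d == -1:
            continue  # stride_for(-1) == 0: this grid axis does not move the pointer
        tile = dims[d] // grid_dim[axis]
        offset_elems += coord[axis] * tile * strides[d]
    return offset_elems * DTYPE_BYTES[dtype]


dims = (16, 1536, 2048)
offset = per_task_byte_offset(dims, row_major_strides(dims), "bf16",
                              (1, 6, 1), (-1, 1, -1), (0, 2, 0))
# 2 tiles * 256 rows * 2048 elements/row * 2 bytes = 2097152
```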
Phase 2: C++ Code Generation
File: src/kernel/runtime.cc
Entry point: Graph::generate_task_graph()
This function orchestrates all code generation:
1. register_mugraph(): walks the KNGraph operators and converts each into FullTaskDesc entries. For each KN_CUSTOMIZED_OP, it queries task_config[op] (a tuple of num_inputs, num_outputs, TaskType, and variant_id, set by Graph::register_task()) to determine the task type and variant. It also creates EventDesc entries for inter-task dependencies and populates first_tasks (the initially ready tasks).
2. print_task_graph(): generates two outputs:
   - Output 1: CUDA code containing three generated functions:
     - construct_task_graph(): loads task_graph.json at runtime, parses it into FullTaskDesc/EventDesc vectors, and creates TMA descriptors for Hopper/Blackwell tasks.
     - _init_persistent_kernel(): sets up tensor pointers from io_configs (torch tensors, cudaMalloc buffers, shuffled tensors, NVSHMEM buffers). Called once during initialization.
     - _execute_task(): a large if/else dispatcher that maps (task_type, variant_id) pairs to the actual kernel function calls. Each branch contains the code string generated by the corresponding TaskRegister::register_*_task() function.
   - Output 2: a JSON task graph that serializes all tasks, events, and dependencies (see the task_graph.json schema below).
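
The generated _execute_task() is an if/else chain keyed on (task_type, variant_id). A dictionary-based Python analogue makes the dispatch rule easy to see; the task-type constants and kernel stand-ins below are hypothetical, and the real dispatcher is generated CUDA.

```python
from typing import Callable, Dict, Tuple

# Hypothetical task-type ids; the real values come from the TaskType enum.
TASK_RMSNORM = 1
TASK_LINEAR = 2

# One entry per generated if/else branch: (task_type, variant_id) -> kernel call.
# Variants of one task type differ only in the dimensions baked into the call.
BRANCHES: Dict[Tuple[int, int], Callable[[dict], str]] = {
    (TASK_RMSNORM, 0): lambda task: f"rmsnorm(hidden=4096) task {task['id']}",
    (TASK_LINEAR, 0): lambda task: f"linear(n=4096, k=4096) task {task['id']}",
    (TASK_LINEAR, 1): lambda task: f"linear(n=1024, k=4096) task {task['id']}",
}


def execute_task(task: dict) -> str:
    """Dispatch on (task_type, variant_id), mirroring the generated branches."""
    key = (task["task_type"], task["variant_id"])
    if key not in BRANCHES:
        raise RuntimeError(f"no generated branch for (task_type, variant_id) = {key}")
    return BRANCHES[key](task)


print(execute_task({"id": 7, "task_type": TASK_LINEAR, "variant_id": 1}))
# linear(n=1024, k=4096) task 7
```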
File: src/kernel/graph.cc
Graph::register_task() maps task name strings to registration functions, for example:
- "moe_w13_fp8_sm100" → register_moe_fp8_sm100_task() → TASK_MOE_W13_FP8_SM100
Each registration function (in src/kernel/task_register.cc) reads tensor dimensions from the TBGraph, generates a CUDA code string calling the templated kernel with those dimensions, and returns a variant_id via register_task_variant(). Same code string → same variant_id (deduplication).
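
The deduplication rule (same code string, same variant_id) can be sketched as a per-task-type lookup table. This VariantRegistry class is a hypothetical model of register_task_variant(), not its implementation; scoping variant ids per task type is an assumption based on dispatch being keyed by the (task_type, variant_id) pair.

```python
class VariantRegistry:
    """Hypothetical model of variant deduplication: identical code strings for the
    same task type share a variant_id; a new string gets the next free id."""

    def __init__(self):
        self._ids = {}  # task_type -> {code string: variant_id}

    def register_task_variant(self, task_type, code):
        ids = self._ids.setdefault(task_type, {})
        if code not in ids:
            ids[code] = len(ids)  # first time this exact code string is seen
        return ids[code]


reg = VariantRegistry()
a = reg.register_task_variant("TASK_LINEAR", "linear_kernel<4096, 4096>(...)")
b = reg.register_task_variant("TASK_LINEAR", "linear_kernel<4096, 4096>(...)")
c = reg.register_task_variant("TASK_LINEAR", "linear_kernel<1024, 4096>(...)")
# a == b == 0 (deduplicated), c == 1 (new dimensions, new variant)
```

Keying on the exact code string means two layers with identical dimensions compile to one branch in _execute_task(), which keeps the generated dispatcher small.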
Phase 3: CUDA Compilation
get_compile_command() in persistent_kernel.py builds the nvcc command with:
- -gencode=arch=compute_90a,code=sm_90a (Hopper), compute_100a,code=sm_100a (Blackwell)
- -DMPK_ENABLE_TMA (Hopper/Blackwell), -DMIRAGE_GRACE_HOPPER or -DMIRAGE_GRACE_BLACKWELL
- -DMODE_OFFLINE, -DMPK_MAX_NUM_BATCHED_REQUESTS=N, -DMPK_MAX_NUM_BATCHED_TOKENS=N, -DMPK_MAX_NUM_PAGES=N, -DMPK_PAGE_SIZE=N, -DMPK_MAX_SEQ_LENGTH=N
- -DMAX_WORKER_PER_SCHEDULER=N (computed from the worker/scheduler ratio)
- … (.so) as a Python extension module
For multi-GPU (NVSHMEM), it adds -rdc=true and the NVSHMEM/MPI include paths and libraries.
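
The serving-limit macros above are plain -D defines, so they can be assembled from a config in a few lines. This helper is a hypothetical sketch, not get_compile_command(): it covers only the -D flags listed here, and rounding the worker/scheduler ratio up is an assumption about how MAX_WORKER_PER_SCHEDULER is computed.

```python
def serving_defines(max_batched_requests, max_batched_tokens, max_pages,
                    page_size, max_seq_length, num_workers, num_schedulers,
                    offline=True):
    """Build the serving-related -D flags (hypothetical helper)."""
    flags = ["-DMODE_OFFLINE"] if offline else []  # other modes are not modeled here
    flags += [
        f"-DMPK_MAX_NUM_BATCHED_REQUESTS={max_batched_requests}",
        f"-DMPK_MAX_NUM_BATCHED_TOKENS={max_batched_tokens}",
        f"-DMPK_MAX_NUM_PAGES={max_pages}",
        f"-DMPK_PAGE_SIZE={page_size}",
        f"-DMPK_MAX_SEQ_LENGTH={max_seq_length}",
    ]
    # Assumption: round the ratio up so every worker is covered by some scheduler.
    max_worker_per_scheduler = -(-num_workers // num_schedulers)
    flags.append(f"-DMAX_WORKER_PER_SCHEDULER={max_worker_per_scheduler}")
    return flags


cmd = ["nvcc", "test.cu", *serving_defines(8, 64, 1024, 64, 4096, 96, 40)]
```

Architecture flags, include paths, and output options are deliberately left out, since their exact form is not specified here.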
Phase 4: Runtime Initialization
File: include/mirage/persistent_kernel/persistent_kernel.cuh
init_persistent_kernel() sets up the full runtime state:
1. Meta-tensor mapping: stores 10 meta-tensor pointers in global_runtime_config (step, tokens, input_tokens, output_tokens, num_new_tokens, prompt_lengths, qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len).
2. NVSHMEM init (multi-GPU only): calls nvshmemx_init_attr() and creates NVSHMEM teams for cross-GPU communication.
3. Call the generated _init_persistent_kernel(): this loads the JSON task graph via construct_task_graph(), allocates GPU memory for intermediate tensors, and populates the all_tasks, all_events, and first_tasks vectors.
4. Allocate runtime queues on the GPU:
   - worker_queues[2 * num_workers]: per-worker task queues (local + remote). Each is a circular buffer of TaskId with length per_worker_queue_len (1024).
   - sched_queues[num_schedulers + 1]: per-scheduler event queues plus one global broadcast queue. Circular buffers of EventId.
   - worker_queue_last_ready_task_id[2 * num_workers]: atomic counters for the queue tails.
   - sched_queue_last_ready_event_id[num_schedulers + 1]: atomic counters for the event queue tails.
   - all_event_counters[num_events]: atomic counters tracking how many times each event has been triggered.
   - all_event_num_triggers[num_events]: how many triggers each event needs before it is considered "ready".
5. Copy task/event data to the GPU: all_tasks, all_events, and first_tasks are copied to device memory.
6. Set kernel attributes: sets cudaFuncAttributeMaxDynamicSharedMemorySize for the worker and scheduler kernels.
7. Create streams and events: separate CUDA streams for workers and schedulers (split mode), plus synchronization events.
8. Call init_request_resources(): launches init_kernel, which initializes per-request state (step counters, page queues for MODE_OFFLINE/MODE_ONLINE).
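
The worker and scheduler queues are fixed-length rings paired with a monotonically increasing "last ready" counter, so entry n lives in slot n % length. The single-threaded sketch below models one such queue; the class name and its overflow check are assumptions for illustration, and the GPU version publishes the tail with atomics and reads it with acquire loads rather than plain Python assignments.

```python
class RingQueue:
    """Single-threaded model of a per-worker queue: a fixed-length ring buffer
    plus a monotonically increasing tail ("last ready") counter."""

    def __init__(self, length=1024):
        self.length = length
        self.buf = [None] * length
        self.last_ready = 0    # total entries ever published (the tail)
        self.next_to_read = 0  # consumer-side counter (the head)

    def push(self, item):
        if self.last_ready - self.next_to_read >= self.length:
            raise OverflowError("queue full")  # assumption: overflow is an error
        self.buf[self.last_ready % self.length] = item
        self.last_ready += 1  # on the GPU: an atomic/release update of the tail

    def pop_ready(self):
        """Return every entry published since the previous call, oldest first."""
        items = []
        while self.next_to_read < self.last_ready:
            items.append(self.buf[self.next_to_read % self.length])
            self.next_to_read += 1
        return items


q = RingQueue(length=4)
for task_id in range(4):
    q.push(task_id)
first_batch = q.pop_ready()  # [0, 1, 2, 3]
q.push(4)
q.push(5)                    # slots 0 and 1 are reused
second_batch = q.pop_ready()  # [4, 5]
```

Because the counters only grow, a reader can tell how many new entries exist by comparing counters, without any separate "empty/full" flag.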
Phase 5: Runtime Execution
File: include/mirage/persistent_kernel/persistent_kernel.cuh
Entry point: launch_persistent_kernel(stream)
1. prepare_kernel<<<>>>: resets all queue pointers and event counters to zero and seeds the initial EVENT_END_OF_TASK_GRAPH event to scheduler[0], which kicks off the first iteration.
2. Kernel launch (two modes):
   - Split mode (split_worker_scheduler = true, now the default): launches worker_kernel and scheduler_kernel as separate kernels on separate streams. Workers get WORKER_NUM_THREADS threads per block; schedulers get 32 threads (1 warp). The two kernels are synchronized via CUDA events.
   - Otherwise (split_worker_scheduler = false): launches a single persistent_kernel in which blocks [0, num_workers) run execute_worker() and the remaining blocks run execute_scheduler().
execute_worker()
Each worker thread block runs an infinite loop:
1. Fetch tasks: polls worker_queue_last_ready_task_id[worker_id] with ld_acquire until new tasks appear, then loads a batch of TaskDesc from the queue into shared memory (using cp.async for efficiency).
2. Wait for dependencies: if task_desc->dependent_event != EVENT_INVALID_ID, polls the event counter all_event_counters[event_index] until it reaches num_triggers * iteration_num. For NVSHMEM events, it uses nvshmem_signal_wait_until.
3. Execute task: calls _execute_task(task_desc, runtime_config), which dispatches to the generated kernel code based on (task_type, variant_id).
4. Signal completion: atomically increments all_event_counters[trigger_event_index]. If this was the final trigger for that event, enqueues the event to the appropriate scheduler's queue.
5. Terminate: when a TASK_TERMINATE task is received, the worker returns.
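
The wait and signal steps above interact through all_event_counters. Because a worker waits until a counter reaches num_triggers * iteration_num, counters appear to accumulate across iterations instead of resetting, so the "final trigger" of an iteration would be the increment that lands on a multiple of num_triggers. The sequential model below (a hypothetical EventTable class, with no real concurrency) follows that reading; it is an interpretation, not MPK's code.

```python
class EventTable:
    """Sequential model of all_event_counters and all_event_num_triggers."""

    def __init__(self, num_triggers):
        self.num_triggers = list(num_triggers)
        self.counters = [0] * len(num_triggers)
        self.sched_queue = []  # event indices handed to a scheduler

    def is_ready(self, event_index, iteration_num):
        # Worker-side dependency check: counter >= num_triggers * iteration_num.
        return self.counters[event_index] >= self.num_triggers[event_index] * iteration_num

    def signal(self, event_index):
        # On the GPU this is an atomic add that returns the previous value.
        old = self.counters[event_index]
        self.counters[event_index] = old + 1
        if (old + 1) % self.num_triggers[event_index] == 0:
            # Last completion of this iteration: enqueue the event for a scheduler.
            self.sched_queue.append(event_index)


events = EventTable([3])     # one event that needs 3 task completions
events.signal(0)
events.signal(0)
ready_after_two = events.is_ready(0, 1)    # False: only 2 of 3 triggers
events.signal(0)
ready_after_three = events.is_ready(0, 1)  # True, and event 0 was enqueued once
```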
execute_scheduler()
Each scheduler runs on a single warp (32 threads, only thread 0 active). Up to 4 schedulers can share one SM (4 warps):
1. Fetch events: polls sched_queue_last_ready_event_id[sched_id] for new events.
2. Process each event by type:
   - EVENT_LAUNCH_TASKS / EVENT_LAUNCH_MASSIVE_TASKS: enqueue the task range [first_task_id, last_task_id) to worker queues in round-robin fashion.
   - EVENT_LAUNCH_DEPENDENT_TASKS: similar, but increments iteration_num (for cross-iteration dependencies).
   - EVENT_END_OF_TASK_GRAPH: calls prepare_next_batch() to set up the next inference iteration (finalize the previous batch, allocate KV cache pages, load new tokens). If prepare_next_batch() returns false (no more work), calls terminate_schedulers().
   - EVENT_TERMINATION: sends TASK_TERMINATE to all workers and returns.
3. Task assignment: each scheduler owns a range of workers (my_first_worker to my_last_worker) and round-robins task assignments within this range, using local counters to track queue positions.
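
Round-robin assignment within a scheduler's worker range can be sketched as follows. The class name is hypothetical, my_last_worker is treated as exclusive, and the rotation counter is assumed to persist across events; the description above does not pin down either detail.

```python
class RoundRobinAssigner:
    """Assigns task ranges to workers [my_first_worker, my_last_worker) in turn."""

    def __init__(self, my_first_worker, my_last_worker):
        self.first = my_first_worker
        self.last = my_last_worker  # assumption: exclusive upper bound
        self.next_offset = 0        # local counter, kept across events

    def assign(self, first_task_id, last_task_id):
        """Return (task_id, worker_id) pairs for the half-open task range."""
        num_workers = self.last - self.first
        plan = []
        for task_id in range(first_task_id, last_task_id):
            plan.append((task_id, self.first + self.next_offset))
            self.next_offset = (self.next_offset + 1) % num_workers
        return plan


sched = RoundRobinAssigner(4, 7)  # owns workers 4, 5, 6
first_event = sched.assign(0, 4)     # [(0, 4), (1, 5), (2, 6), (3, 4)]
second_event = sched.assign(10, 12)  # continues the rotation: [(10, 5), (11, 6)]
```

Continuing the rotation across events spreads consecutive small launches over different workers instead of always starting at my_first_worker.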
prepare_next_batch() (defined per mode via #ifdef):
- … next_request_id.
The task_graph.json file is the key intermediate artifact between code generation and runtime. It is generated by print_task_graph() in runtime.cc and loaded by construct_task_graph() at init time.
The task graph JSON is very large and should not be read raw. Always use scripts/parse_task_graph.py to parse and analyze it.

    {
      "all_tasks": [
        {
          "task_type": 0,                 // TaskType enum value
          "variant_id": 0,                // code variant (same task, different dims)
          "inputs": [
            {
              "base_ptr": "tensor_name",  // matches io_configs key
              "offset": 0,                // byte offset from base
              "dims": [128, 4096],
              "strides": [4096, 1],
              "data_type": 1              // dtype enum
            }
          ],
          "outputs": [ /* same structure */ ],
          "trigger_event": 65537,         // EventId this task signals on completion
          "dependent_event": 65536,       // EventId this task waits for before executing
          "request_id": -1,               // task_metadata: which request (-1 = all)
          "expert_offset": -1,            // task_metadata: MoE expert offset
          "kv_idx": -1,                   // task_metadata: KV cache chunk index
          "merge_task_offset": -1,        // task_metadata: split-KV merge offset
          "task_offset": -1               // task_metadata: NVSHMEM team mapping
        }
      ],
      "all_events": [
        {
          "event_type": 0,                // EVENT_TERMINATION, EVENT_LAUNCH_TASKS, etc.
          "num_triggers": 1,              // how many task completions before this event fires
          "first_task_id": 0,             // range of tasks this event unlocks
          "last_task_id": 4
        }
      ],
      "first_tasks": [1, 2, 3]            // TaskIds ready to execute immediately
    }

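Since the file should not be read raw, a few aggregate checks go a long way. The sketch below is not scripts/parse_task_graph.py; it is a hypothetical helper that counts tasks per (task_type, variant_id) and compares the number of tasks triggering each event (using the lower 32 bits of trigger_event as the event index) with that event's num_triggers. A mismatch is a hint worth investigating rather than proof of a bug. The sample graph is made up and omits the // annotations, since JSON has no comment syntax.

```python
import json
from collections import Counter

EVENT_INDEX_MASK = (1 << 32) - 1  # event_index occupies the lower 32 bits of an EventId

# A tiny, made-up graph in the schema above.
SAMPLE = """
{
  "all_tasks": [
    {"task_type": 7, "variant_id": 0, "trigger_event": 1, "dependent_event": 0},
    {"task_type": 7, "variant_id": 0, "trigger_event": 1, "dependent_event": 0},
    {"task_type": 9, "variant_id": 1, "trigger_event": 2, "dependent_event": 1}
  ],
  "all_events": [
    {"event_type": 1, "num_triggers": 1, "first_task_id": 0, "last_task_id": 2},
    {"event_type": 1, "num_triggers": 2, "first_task_id": 2, "last_task_id": 3},
    {"event_type": 2, "num_triggers": 1, "first_task_id": 0, "last_task_id": 0}
  ],
  "first_tasks": [0, 1]
}
"""


def summarize_task_graph(graph):
    tasks, events = graph["all_tasks"], graph["all_events"]
    variants = Counter((t["task_type"], t["variant_id"]) for t in tasks)
    triggers = Counter(t["trigger_event"] & EVENT_INDEX_MASK for t in tasks)
    # Events whose number of triggering tasks differs from num_triggers.
    mismatched = sorted(
        idx for idx, count in triggers.items()
        if idx < len(events) and count != events[idx]["num_triggers"]
    )
    return variants, triggers, mismatched


graph = json.loads(SAMPLE)  # for a real file: with open(path) as f: graph = json.load(f)
variants, triggers, mismatched = summarize_task_graph(graph)
```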
Event types (runtime_header.h):
- EVENT_TERMINATION (0): terminate the kernel
- EVENT_LAUNCH_TASKS (1): enqueue a range of tasks to one scheduler
- EVENT_END_OF_TASK_GRAPH (2): end of one forward pass; triggers prepare_next_batch
- EVENT_EMPTY (3): no-op
- EVENT_LAUNCH_MASSIVE_TASKS (4): large task range split across all local schedulers
- EVENT_LAUNCH_DEPENDENT_TASKS (5): cross-iteration dependent tasks
TaskId encoding (64-bit): [iteration_num: upper 32 bits][position_index: lower 32 bits]
EventId encoding (64-bit): [nvshmem_tag: upper bits][gpu_id: middle 16 bits][event_index: lower 32 bits]
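
Both ID layouts are ordinary bit packing. This sketch assumes gpu_id occupies bits 32 to 47 and nvshmem_tag starts at bit 48, which is one consistent reading of "middle 16 bits" and "upper bits"; the helper names are hypothetical.

```python
MASK16 = (1 << 16) - 1
MASK32 = (1 << 32) - 1


def make_task_id(iteration_num, position_index):
    """TaskId: iteration_num in the upper 32 bits, position_index in the lower 32."""
    return ((iteration_num & MASK32) << 32) | (position_index & MASK32)


def split_task_id(task_id):
    return task_id >> 32, task_id & MASK32


def make_event_id(nvshmem_tag, gpu_id, event_index):
    """EventId, assumed layout: [nvshmem_tag: bits 48+][gpu_id: bits 32-47][event_index: bits 0-31]."""
    return (nvshmem_tag << 48) | ((gpu_id & MASK16) << 32) | (event_index & MASK32)


def split_event_id(event_id):
    return event_id >> 48, (event_id >> 32) & MASK16, event_id & MASK32


print(split_task_id(make_task_id(3, 17)))          # (3, 17)
print(split_event_id(make_event_id(1, 2, 65537)))  # (1, 2, 65537)
```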
RuntimeConfig (runtime_header.h)
Global configuration struct stored in GPU global memory. Contains:
- num_workers, num_local_schedulers, num_remote_schedulers, num_gpus, my_gpu_id
- worker_queues[][], sched_queues[][], atomic tail counters
- all_tasks[], all_events[], all_event_counters[], first_tasks[]
- step[], tokens[], input_tokens[], output_tokens[], KV cache page management arrays
- split_worker_scheduler, CUDA streams/events for synchronization
FullTaskDesc (runtime_header.h)
Full task descriptor used during code generation and JSON serialization. Contains tensor descriptors with shapes/strides, event IDs, and task metadata.
TaskDesc (runtime_header.h)
Compact runtime task descriptor (16-byte aligned). Contains only raw pointers (input_ptrs[7], output_ptrs[3]), TMA descriptor pointers (on Hopper/Blackwell), event IDs, and task metadata. Constructed from FullTaskDesc at init time by resolving tensor names to GPU pointers.
TaskDesc::TaskMetadata (union)
Per-task metadata packed into 8 bytes. Its interpretation depends on the task type:
- expert_offset (int): MoE, which expert subset this task handles
- request_id (int16) + kv_idx (uint16) + merge_task_offset (int): paged attention
- task_offset (int): NVSHMEM team index for multi-GPU tasks
EventDesc (runtime_header.h)
Event descriptor: event_type, num_triggers (how many completions are needed), first_task_id/last_task_id (the range of tasks this event unlocks).
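
The 8-byte budget of TaskMetadata can be checked with Python's struct module. This sketch assumes the paged-attention fields are laid out in the listed order (request_id, kv_idx, merge_task_offset) on a little-endian GPU; it also shows that, because union members share storage, reading the same bytes as expert_offset gives an unrelated value.

```python
import struct

# Assumed paged-attention layout, little-endian, no padding:
# int16 request_id, uint16 kv_idx, int32 merge_task_offset.
PAGED_ATTENTION = struct.Struct("<hHi")
# The other members are a single int32 (expert_offset or task_offset) at offset 0.
SINGLE_INT = struct.Struct("<i")

raw = PAGED_ATTENTION.pack(3, 5, -1)  # request_id=3, kv_idx=5, merge_task_offset=-1

print(PAGED_ATTENTION.size)            # 8: the whole union budget
print(PAGED_ATTENTION.unpack(raw))     # (3, 5, -1)
print(SINGLE_INT.unpack_from(raw)[0])  # 327683 == 3 + (5 << 16): same bytes, other member
```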
TensorDesc (runtime_header.h)
Tensor metadata for JSON serialization: num_dims, base_ptr (a name string at codegen time, resolved to a GPU pointer at init), dim[], stride[], data_type, and optional TMA descriptor pointers.