collect-workloads
// Auto-collect workloads from SGLang inference runs using FlashInfer logging API. Dumps tensors, sanitizes them according to kernel definitions, and submits PR to flashinfer-trace workload repo.
| name | collect-workloads |
| description | Auto-collect workloads from SGLang inference runs using FlashInfer logging API. Dumps tensors, sanitizes them according to kernel definitions, and submits PR to flashinfer-trace workload repo. |
Collect real-world workloads by running SGLang inference with FlashInfer Level 10 logging, then sanitize and submit to the flashinfer-ai/flashinfer-trace HuggingFace dataset.
No code changes to SGLang or FlashInfer are required. Collection works entirely through FlashInfer's built-in logging API.
| Script | Purpose |
|---|---|
| scripts/collect_stream.py | Preferred end-to-end streaming script: per-batch-size inference → sanitize → incremental HF push → eval → trace upload |
| scripts/collect_workloads.py | Older entry point: runs SGLang inference + sanitizes dumps (collects all batch sizes then pushes once) |
| scripts/sanitize_dumps.py | Converts FlashInfer Level 10 dump dirs → JSONL + safetensors; supports --max-new-workloads N for streaming mode |
# Must run under gpu-lock so CUDA_VISIBLE_DEVICES is set
tools/gpu-lock --gpus 8 --exec-timeout 10800 -- \
python3 scripts/collect_stream.py \
--def-name gqa_paged_decode_h5_kv1_d128_ps64 \
--model-key llama-4-scout-ps64 \
--model-path /path/to/model \
--batch-sizes 64 128 \
--pr-num 263 \
[--peer-node-addr nvl72089-T16] \
[--trace-dir tmp/flashinfer-trace]
# Ragged prefill — add disable-radix-cache + disable-piecewise-cuda-graph
tools/gpu-lock --gpus 8 --exec-timeout 10800 -- \
python3 scripts/collect_stream.py \
--def-name gqa_ragged_prefill_causal_h5_kv1_d128 \
--model-key llama-4-scout \
--model-path /path/to/model \
--batch-sizes 64 128 \
--pr-num 265 \
--extra-server-flag --disable-radix-cache --disable-piecewise-cuda-graph
# Paged prefill — add enable-deterministic-inference
tools/gpu-lock --gpus 8 --exec-timeout 10800 -- \
python3 scripts/collect_stream.py \
--def-name gqa_paged_prefill_causal_h5_kv1_d128_ps64 \
--model-key llama-4-scout-ps64 \
--model-path /path/to/model \
--batch-sizes 64 128 \
--pr-num 264 \
--extra-server-flag --disable-cuda-graph --enable-deterministic-inference
Streaming workflow per batch size:
1. bench_serving.py with DUMP_MAX_COUNT=500 (exhausted in round 1 of 2)
2. sanitize_dumps.py --max-new-workloads 4: appends 4 diverse workloads
3. rm -rf dump dir

After all batch sizes: flashinfer-bench run eval → push trace → PR2 done.
Key flags:
- --dump-count 500 (default): budget per server session
- --workloads-per-batch 4 (default): workloads added per batch size
- --num-batches 2 (default): inference rounds; budget typically hit in round 1
- --no-eval: skip eval+trace push (useful when flashinfer-bench is unavailable)
- --no-push: dry run; collect and sanitize without uploading
- --replace-first: replace instead of append on first batch size

Auto-detection from definition tags:
- tp:N → sets --tp N (use CUDA_VISIBLE_DEVICES=0,0 to simulate TP=2 on 1 GPU)
- page_size const axis → sets --page-size N

git -C tmp/flashinfer pull && git -C tmp/sglang pull
conda run -n flashinfer_bench pip install -e tmp/flashinfer --no-build-isolation
conda run -n flashinfer_bench pip install -e "tmp/sglang/python[all]"
- --definitions <name> [name ...]: specific definitions by name
- --op-type <type>: all definitions under definitions/{op_type}/
- --all: all definitions in the repo

Parses fi_api:<dotted.api.name> tags from each definition to build FLASHINFER_DUMP_INCLUDE:
- Wrapper classes (e.g. BatchDecodeWithPagedKVCacheWrapper) → include .run, and .plan if the definition has int32/int64 inputs
- BatchPrefillWithRaggedKVCacheWrapper: SGLang calls .forward()/.forward_return_lse() (not .run()), so those are automatically added to FLASHINFER_DUMP_INCLUDE for Ragged wrappers
- Plain functions (e.g. rmsnorm) → include by function name

Key env vars set automatically:
FLASHINFER_LOGLEVEL=10
FLASHINFER_DUMP_DIR=./workload_dumps_<timestamp>
FLASHINFER_DUMP_SAFETENSORS=1
FLASHINFER_DUMP_INCLUDE=<fi_api patterns> # only log matching API calls
FLASHINFER_DUMP_EXCLUDE=*.__init__
FLASHINFER_DUMP_MAX_COUNT=500 # ~4 batches × 16 layers × 8 TP ranks per session
FLASHINFER_DUMP_MAX_SIZE_GB=30
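The tag-to-pattern rules listed before the env vars can be sketched roughly as follows. This is an illustration only, not the actual logic in the collection scripts: the helper name build_dump_include, the definition layout (a tags list plus an inputs dict with dtype fields), the Class.method pattern form, and the comma separator are all assumptions.

```python
def build_dump_include(definition: dict) -> str:
    """Hypothetical helper: derive FLASHINFER_DUMP_INCLUDE patterns from fi_api tags."""
    # Assumed layout: definition["inputs"] maps input name -> {"dtype": ...}.
    has_int_inputs = any(
        spec.get("dtype") in ("int32", "int64")
        for spec in definition.get("inputs", {}).values()
    )
    patterns = []
    for tag in definition.get("tags", []):
        if not tag.startswith("fi_api:"):
            continue
        # Keep only the last component of the dotted API name.
        leaf = tag[len("fi_api:"):].rsplit(".", 1)[-1]
        if leaf.endswith("Wrapper"):
            patterns.append(f"{leaf}.run")
            if has_int_inputs:
                # plan() carries the structural (integer) tensors.
                patterns.append(f"{leaf}.plan")
            if "Ragged" in leaf:
                # SGLang calls forward()/forward_return_lse() on ragged wrappers.
                patterns += [f"{leaf}.forward", f"{leaf}.forward_return_lse"]
        else:
            # Plain function such as rmsnorm: match by function name.
            patterns.append(leaf)
    return ",".join(patterns)


decode_def = {
    "tags": ["fi_api:flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper"],
    "inputs": {"q": {"dtype": "bfloat16"}, "kv_indptr": {"dtype": "int32"}},
}
print(build_dump_include(decode_def))
```

The resulting string would then be exported as FLASHINFER_DUMP_INCLUDE before the server starts; check the real scripts for the exact pattern syntax they emit.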
DUMP_MAX_COUNT sizing: with --restart-per-batch-size, each server session gets its own DUMP_MAX_COUNT budget. For a TP=8, 16-layer model, one forward pass produces 16 × 8 = 128 run() calls, so 4 passes produce 512; a budget of 500 therefore covers just under 4 full forward passes. Use 500 as the standard value when collecting per-batch-size.
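The sizing arithmetic can be reproduced for other models with a few lines of Python. This is a back-of-the-envelope sketch that assumes one run() call per layer per TP rank in each forward pass; the function names are illustrative.

```python
def run_calls_per_pass(num_layers: int, tp_ranks: int) -> int:
    # Assumption: each layer issues one run() call on every TP rank per forward pass.
    return num_layers * tp_ranks


def passes_covered(dump_max_count: int, num_layers: int, tp_ranks: int) -> float:
    # How many full forward passes fit inside the dump budget.
    return dump_max_count / run_calls_per_pass(num_layers, tp_ranks)


calls = run_calls_per_pass(num_layers=16, tp_ranks=8)
print(calls)                       # 128
print(4 * calls)                   # 512
print(passes_covered(500, 16, 8))  # 3.90625
```

For a deeper model or a different TP degree, scale FLASHINFER_DUMP_MAX_COUNT so that the number of covered passes stays near 4.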
Setting FLASHINFER_TRACE_DUMP=1 and FLASHINFER_TRACE_DUMP_DIR=<dir> alongside the
logging vars above tells FlashInfer to write a Definition JSON for every
@flashinfer_api(trace=...)-decorated call (one file per unique (op, shape)). This means
one SGLang run can produce both workload tensors and definition JSONs, which is the
fastest way to pick up a shape that turned out to be missing from the dataset (typical
case: a new page-size variant or a quant-config variant).
export FLASHINFER_TRACE_DUMP=1
export FLASHINFER_TRACE_DUMP_DIR=tmp/dumps/fi_trace_{def_name}
# ...your existing FLASHINFER_LOGLEVEL/FLASHINFER_DUMP_* vars stay unchanged...
After the run, stage the new JSONs into the dataset with the snippet from /extract-kernel-definitions Path A3, then normalize them for the validator with the A3b fix-up snippet (the dumper's def _xxx_reference and dtype: "unknown" need patching).
The trace dump is independent of the logging API and adds negligible overhead, so it's
safe to leave on for any collection run. Skip it only when the definitions are already
known-correct and stable.
Inference source: synthetic random prompts (default, --dataset random) or real ShareGPT prompts (--dataset sharegpt).
- random (default): generates token-id prompts of a chosen length via sample_random_requests (ported from InferenceX utils/bench_serving/benchmark_serving.py). Use when you need controlled prefill length and a guaranteed decode budget. Each request decodes for exactly --osl tokens because ignore_eos=True is on by default.
  - Two length settings are used in this workflow: --isl 1024 --osl 1024 (decode-heavy, big-batch decode shapes) and --isl 8192 --osl 1024 (prefill-heavy).
  - --random-range-ratio jitters lengths uniformly in [ratio*len, len]. Leave at 1.0 for exact lengths.
- sharegpt: real prompts from anon8231489123/ShareGPT_Vicuna_unfiltered. Length distribution is uncontrolled; use only when prompt realism matters more than coverage.

Batch sizes: [8, 32, 64, 128], powers of 2 matching SGLang CUDA graph capture points. Run multiple rounds of each for KV-length diversity.
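To make the --random-range-ratio behaviour concrete, the sketch below draws a length uniformly from [ratio*len, len]. It is a simplified stand-in, not the actual sample_random_requests implementation; sample_length is a hypothetical name.

```python
import random


def sample_length(target_len: int, range_ratio: float, rng: random.Random) -> int:
    """Draw an integer length uniformly from [range_ratio * target_len, target_len]."""
    low = max(1, int(target_len * range_ratio))
    return rng.randint(low, target_len)  # randint is inclusive on both ends


rng = random.Random(0)
exact = [sample_length(1024, 1.0, rng) for _ in range(5)]
jittered = [sample_length(1024, 0.5, rng) for _ in range(5)]
print(exact)     # every value is 1024 when the ratio is 1.0
print(jittered)  # values fall between 512 and 1024
```

With a ratio of 1.0 the lower bound equals the target, which is why that setting gives exact lengths.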
Dispatch: sustained inflight (matches InferenceX --request-rate inf). SGLang's async semaphore caps concurrent requests at batch_size and backfills on each completion, so new prefills overlap with ongoing decodes. This yields mixed prefill+decode batches and varied intra-batch kv_lens.
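The sustained-inflight pattern can be illustrated with a plain asyncio semaphore. This is a minimal sketch, not the bench_serving.py code: the dispatch function is hypothetical and asyncio.sleep stands in for a real HTTP request.

```python
import asyncio


async def dispatch(num_requests: int, batch_size: int) -> int:
    """Send num_requests with at most batch_size in flight; return the peak concurrency."""
    sem = asyncio.Semaphore(batch_size)
    inflight = 0
    peak = 0

    async def one_request(i: int) -> None:
        nonlocal inflight, peak
        async with sem:  # a slot frees up as soon as any request completes
            inflight += 1
            peak = max(peak, inflight)
            # Stand-in for a request; varied durations stagger the completions.
            await asyncio.sleep(0.001 * (1 + i % 3))
            inflight -= 1

    await asyncio.gather(*(one_request(i) for i in range(num_requests)))
    return peak


peak = asyncio.run(dispatch(num_requests=40, batch_size=8))
print(peak)  # 8
```

Because each finished request immediately admits a waiting one, the server sees new prefills arriving while older requests are still decoding.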
Per-batch-size isolation (--restart-per-batch-size): pass this flag to bench_serving.py when using FLASHINFER_DUMP_MAX_COUNT. Without it, the first batch size exhausts the dump budget (DUMP_MAX_COUNT is a global counter per server process) and later batch sizes capture nothing. With it, each batch size gets its own server session and therefore its own fresh counter.
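A toy model of the dump counter shows why isolation matters. The call counts below are made-up illustrative numbers and captured_per_batch_size is a hypothetical helper, not part of any script.

```python
def captured_per_batch_size(calls_per_batch_size: dict, budget: int, restart: bool) -> dict:
    """Simulate how many API calls get dumped for each batch size."""
    captured = {}
    remaining = budget
    for batch_size, calls in calls_per_batch_size.items():
        if restart:
            remaining = budget  # new server process, so the counter starts fresh
        taken = min(calls, remaining)
        remaining -= taken
        captured[batch_size] = taken
    return captured


calls = {64: 800, 128: 800}  # illustrative number of matching calls per batch size
print(captured_per_batch_size(calls, budget=500, restart=False))  # {64: 500, 128: 0}
print(captured_per_batch_size(calls, budget=500, restart=True))   # {64: 500, 128: 500}
```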
Standard collection invocation with isolation:
FLASHINFER_DUMP_MAX_COUNT=500 \
FLASHINFER_DUMP_INCLUDE="BatchDecodeWithPagedKVCacheWrapper*" \
FLASHINFER_DUMP_EXCLUDE="*.__init__" \
... \
python3 examples/sglang_bench/bench_serving.py \
--model <model-key> \
--model-path /path/to/model \
--dataset random --isl 1024 --osl 1024 \
--batch-sizes 64 128 \
--num-batches 4 \
--restart-per-batch-size \
--disable-cuda-graph
For prefill-heavy coverage, run a second pass with --isl 8192 --osl 1024. For ShareGPT prompts pass --dataset sharegpt (no --isl/--osl needed).
Three execution modes (chosen automatically based on definition type):
| Mode | When | How |
|---|---|---|
| SGLang offline Engine | Decode-only definitions | engine.generate() with exact batch size per call — guarantees decode sees B concurrent sequences |
| SGLang HTTP server (paged) | Paged-prefill definitions | Launches server with --enable-deterministic-inference to force use_ragged=False, sends prefix-sharing requests via /v1/chat/completions |
| SGLang HTTP server (ragged) | Ragged-prefill definitions (BatchPrefillWithRaggedKVCacheWrapper) | Launches server with --disable-piecewise-cuda-graph (no --enable-deterministic-inference), sends requests with max_tokens=1 |
Critical ragged prefill flags: --disable-cuda-graph alone is insufficient. SGLang always captures a piecewise CUDA graph for prefill; during capture is_in_piecewise_cuda_graph()=True forces use_ragged=False, so the captured graph only uses BatchPrefillWithPagedKVCacheWrapper. Adding --disable-piecewise-cuda-graph prevents the capture, ensuring every prefill executes eagerly through BatchPrefillWithRaggedKVCacheWrapper. Do not add --enable-deterministic-inference for ragged — it forces use_ragged=False entirely.
sanitize_dumps.py processes dump dirs:
- Matches dumps to definitions by fi_api function name
- Pairs plan() dumps with the following run() dump (same PID) to get structural tensors
- Renames tensors: paged_kv_indptr→kv_indptr, paged_kv_indices→kv_indices, etc.
- int32/int64 tensors (structural: indptrs, indices) → saved to safetensors blob
- Floating-point tensors (q, k_cache, v_cache) → {"type": "random"} (shapes validated but values irrelevant for benchmarking)
- Scalars (sm_scale) → {"type": "scalar", "value": <float>}
- Truncates kv_indices to kv_indptr[-1] (SGLang over-allocates KV pool)

Runs the baseline solution against collected workloads before PR submission:
flashinfer-bench run --local {trace_dir} --definitions {def_name} --solutions baseline
# → writes {trace_dir}/traces/{def_name}_baseline.jsonl
All entries must have evaluation.status == "PASSED". If any fail, do not submit PR 2.
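A quick way to enforce this gate is to scan the trace file before opening the PR. The sketch below assumes each JSONL line is an object with a nested evaluation.status field, as the rule above implies; failed_entries is a hypothetical helper.

```python
import json
from pathlib import Path


def failed_entries(trace_path: str) -> list:
    """Return (line_number, status) for every trace entry that is not PASSED."""
    failures = []
    lines = Path(trace_path).read_text().splitlines()
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines
        entry = json.loads(line)
        status = entry.get("evaluation", {}).get("status")
        if status != "PASSED":
            failures.append((line_no, status))
    return failures
```

An empty list means the trace is safe to submit; anything else lists the lines to investigate first.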
One HuggingFace PR per definition. PR 1 (GitHub flashinfer-bench) must already be open.
PR 2 contents:
- solutions/baseline/{op_type}/{def_name}/flashinfer_wrapper_*.json: FlashInfer API wrapper (calls BatchDecodeWithPagedKVCacheWrapper or BatchPrefillWithPagedKVCacheWrapper, not reference_impl)
- workloads/{op_type}/{def_name}.jsonl
- blob/workloads/{op_type}/{def_name}/*.safetensors
- definitions/{op_type}/{def_name}.json (copied from PR 1)
- tests/references/test_{def_name}.py (copied from PR 1)
- traces/{op_type}/{def_name}.jsonl (baseline eval trace, all PASSED)

PR description must include the full stdout of collect_workloads.py sglang under ## SGLang Collection Log. The log must show real ShareGPT inference with diverse (batch_size, kv_length) pairs; uniform tiny KV caches (e.g. batch_size=4096 with 1-page contexts) indicate synthetic data, not real inference.
{flashinfer_trace_dir}/workloads/{op_type}/{def_name}.jsonl
{flashinfer_trace_dir}/blob/workloads/{op_type}/{def_name}/{def_name}_{uuid}.safetensors
Each JSONL line:
{
"definition": "gqa_paged_decode_h32_kv8_d128_ps1",
"workload": {
"uuid": "a1b2c3d4-...",
"axes": {"len_indptr": 33, "num_kv_indices": 4096},
"inputs": {
"q": {"type": "random"},
"k_cache": {"type": "random"},
"v_cache": {"type": "random"},
"kv_indptr": {"type": "safetensors", "path": "...", "tensor_key": "kv_indptr"},
"kv_indices": {"type": "safetensors", "path": "...", "tensor_key": "kv_indices"},
"kv_last_page_len": {"type": "safetensors", "path": "...", "tensor_key": "kv_last_page_len"},
"sm_scale": {"type": "scalar", "value": 0.0883}
}
}
}
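The sanitization rules and the JSONL layout shown here fit together roughly as in the sketch below. It is not the sanitize_dumps.py implementation: the helper names, the dtype-string inputs, and the rename table are assumptions chosen to mirror the rules described earlier.

```python
import uuid

STRUCTURAL_DTYPES = {"int32", "int64"}
# Assumed rename table, following the paged_kv_* examples in the rules.
RENAMES = {"paged_kv_indptr": "kv_indptr", "paged_kv_indices": "kv_indices"}


def truncate_kv_indices(kv_indices: list, kv_indptr: list) -> list:
    # Keep only the entries referenced by the last indptr value.
    return kv_indices[: kv_indptr[-1]]


def build_entry(def_name: str, dumped_dtypes: dict, scalars: dict,
                axes: dict, blob_path: str) -> dict:
    """Build one JSONL workload entry from dumped tensor dtypes and scalar values."""
    inputs = {}
    for raw_name, dtype in dumped_dtypes.items():
        name = RENAMES.get(raw_name, raw_name)
        if dtype in STRUCTURAL_DTYPES:
            # Structural tensors keep their real values in a safetensors blob.
            inputs[name] = {"type": "safetensors", "path": blob_path, "tensor_key": name}
        else:
            # Floating-point tensors only need their shape, so values are random.
            inputs[name] = {"type": "random"}
    for name, value in scalars.items():
        inputs[name] = {"type": "scalar", "value": float(value)}
    return {
        "definition": def_name,
        "workload": {"uuid": str(uuid.uuid4()), "axes": axes, "inputs": inputs},
    }


entry = build_entry(
    "gqa_paged_decode_h32_kv8_d128_ps1",
    {"q": "bfloat16", "paged_kv_indptr": "int32", "paged_kv_indices": "int32"},
    {"sm_scale": 0.0883},
    {"len_indptr": 33, "num_kv_indices": 4096},
    "example.safetensors",  # placeholder path for illustration
)
print(sorted(entry["workload"]["inputs"]))
```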
No tensor dumps generated: verify FLASHINFER_LOGLEVEL=10 is set before any FlashInfer import; check FLASHINFER_DUMP_INCLUDE matches actual API function names; confirm --attention-backend flashinfer.
run() not captured: look for cudaErrorStreamCaptureUnsupported in SGLang log. Fix: always pass both --disable-cuda-graph and --disable-piecewise-cuda-graph.
Ragged prefill yields 0 workloads: two possible causes — (1) --enable-deterministic-inference is set, which forces use_ragged=False globally — never set this for ragged definitions; (2) piecewise CUDA graph is active (default even without --enable-deterministic-inference), so is_in_piecewise_cuda_graph()=True during capture forces use_ragged=False, and the cached graph always routes to BatchPrefillWithPagedKVCacheWrapper. Fix: add --disable-piecewise-cuda-graph. The script auto-detects ragged prefill definitions and adds this flag automatically.
Constant axis mismatch across TP: use --skip-const-axis-check when collecting TP=1 dumps for a TP=2 definition (structural tensors are identical across TP).
SGLang not wired to target FlashInfer API: grep for the fi_api function name in tmp/sglang/python/sglang/srt/. If missing, the onboard-model skill handles submitting a SGLang PR to wire it in.
Run /clone-repos first. Then:
/clone-repos
/extract-kernel-definitions --model-name <model>
/collect-workloads --op-type <op_type> --model-path /path/to/model