# impl-jit-kernel
| Field | Value |
|---|---|
| name | impl-jit-kernel |
| description | Guide for implementing CUDA or CPU JIT kernels in mllm-kernel. Use when the user asks to create, add, or implement a new kernel in mllm-kernel. |
mllm-kernel uses a JIT (Just-In-Time) compilation system built on tvm_ffi. Kernels are written in C++20 (.cuh for CUDA, .cpp for CPU), validated at runtime via TensorMatcher, and exposed to Python through a @jit decorator. No pre-compilation is needed -- kernels compile on first call and are cached at ~/.cache/mllm_kernel/.
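The compile-once behaviour can be pictured with a small plain-Python model. `LazyKernel` and `fake_compile` are hypothetical names for illustration only. This is not how `@jit` or tvm_ffi are implemented, and it models only the in-process cache, not the on-disk `.so` cache under `~/.cache/mllm_kernel/`.

```python
from typing import Callable, Dict, Hashable, List


class LazyKernel:
    """Toy model: build a kernel the first time a key is seen, then reuse it."""

    def __init__(self, compile_fn: Callable[[Hashable], Callable]) -> None:
        self._compile_fn = compile_fn  # hypothetical stand-in for the JIT build step
        self._cache: Dict[Hashable, Callable] = {}
        self.compile_count = 0

    def __call__(self, key: Hashable, *args):
        fn = self._cache.get(key)
        if fn is None:
            # First call with this key: "compile" and remember the result.
            fn = self._compile_fn(key)
            self._cache[key] = fn
            self.compile_count += 1
        return fn(*args)


def fake_compile(scale: int) -> Callable[[List[float]], List[float]]:
    # Pretend build step: returns a function that multiplies by `scale`.
    return lambda xs: [x * scale for x in xs]


kernel = LazyKernel(fake_compile)
print(kernel(2, [1.0, 2.0]))  # first call builds the key=2 variant -> [2.0, 4.0]
print(kernel(2, [3.0]))       # served from the cache -> [6.0]
print(kernel.compile_count)   # -> 1
```

In the real system the cache key also covers template arguments (the `args=` passed to `@jit`), which is why each distinct specialization compiles separately.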
## Directory layout

For a kernel named `my_kernel`:

```text
mllm-kernel/
  mllm_kernel/
    cuda/
      csrc/my_kernel.cuh             # CUDA kernel implementation
      jit/my_kernel.py               # Python JIT wrapper
      jit/__init__.py                # Add export here
    cpu/
      csrc/my_kernel.cpp             # CPU kernel implementation (Highway SIMD)
      include/mllm_kernel/cpu/
        my_kernel.hpp                # CPU SIMD body (NO #pragma once)
      jit/my_kernel.py               # Python JIT wrapper
      jit/__init__.py                # Add export here
  tests/test_my_kernel.py            # Pytest correctness tests
  benchmarks/bench_my_kernel.py      # Profiler benchmark vs PyTorch reference
```
## CUDA kernels

### `.cuh` kernel

Create `mllm_kernel/cuda/csrc/my_kernel.cuh`:

```cpp
#pragma once

#include <mllm_kernel/tensor.hpp>  // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType
#include <mllm_kernel/utils.hpp>   // RuntimeCheck, Panic, div_ceil
#include <mllm_kernel/utils.cuh>   // LaunchKernel, fp16_t, bf16_t, PDL helpers

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstdint>
#include <limits>

namespace {

// ---------------------------------------------------------------------------
// 1. Parameter struct (trivially copyable, passed to kernel by value)
// ---------------------------------------------------------------------------
struct MyKernelParams {
  const float* __restrict__ input;
  float* __restrict__ output;
  int32_t num_elements;
};

// ---------------------------------------------------------------------------
// 2. CUDA kernel
// ---------------------------------------------------------------------------
__global__ void my_kernel(const MyKernelParams params) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= params.num_elements) return;  // the last block may be partially filled
  params.output[idx] = params.input[idx] * 2.0f;
}

// ---------------------------------------------------------------------------
// 3. Host-side launcher (entry point for TVM FFI binding)
// ---------------------------------------------------------------------------
struct MyKernel {
  static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
    using namespace mllm_kernel::host;

    // --- Validate tensors ---
    SymbolicSize N{"num_elements"};
    SymbolicDevice device;
    (void)TensorMatcher({N})
        .with_dtype<float>()
        .with_device<kDLCUDA>(device)
        .verify(input);
    (void)TensorMatcher({N})  // reuses N and device: output must match input
        .with_dtype<float>()
        .with_device(device)
        .verify(output);

    const int64_t n = N.unwrap();
    RuntimeCheck(n > 0, "num_elements must be positive, got ", n);
    // num_elements is stored as int32_t, so reject sizes that would overflow it.
    RuntimeCheck(n <= std::numeric_limits<int32_t>::max(),
                 "num_elements must fit in int32_t, got ", n);

    // --- Build params ---
    MyKernelParams params{
        .input = static_cast<const float*>(input.data_ptr()),
        .output = static_cast<float*>(output.data_ptr()),
        .num_elements = static_cast<int32_t>(n),
    };

    // --- Launch ---
    constexpr int kBlock = 256;
    const int grid = static_cast<int>(div_ceil(n, kBlock));
    LaunchKernel(grid, kBlock, device.unwrap())(my_kernel, params);
  }
};

}  // namespace
```
Key rules:

- Wrap all kernel code in `namespace {}` (an anonymous namespace).
- Expose the host entry point as a struct with a `static void run(tvm::ffi::TensorView ...)` method.
- Validate every tensor with `TensorMatcher` before reading `.data_ptr()`.
- Never dereference `.data_ptr()` in host code: for CUDA tensors it returns a GPU pointer.
- Use `LaunchKernel` to launch -- it handles stream resolution and error checking.

### Python JIT wrapper

Create `mllm_kernel/cuda/jit/my_kernel.py`:

```python
"""JIT wrapper for my_kernel CUDA kernel."""

import torch

from mllm_kernel.jit_utils import jit


@jit(
    args=[],
    device="cuda",
    cuda_files=["my_kernel.cuh"],
    cpp_wrappers=[],
    cuda_wrappers=[("my_kernel", "MyKernel::run")],
    func_name="my_kernel",
)
def _kernel(compiled_module, input: torch.Tensor, output: torch.Tensor) -> None:
    # `compiled_module` is supplied by @jit; callers pass only (input, output).
    compiled_module.my_kernel(input, output)


def my_kernel(input: torch.Tensor) -> torch.Tensor:
    """Double every element in *input*.

    Parameters
    ----------
    input : torch.Tensor
        1-D float32 tensor on CUDA.

    Returns
    -------
    torch.Tensor
        Same shape and dtype as *input*.
    """
    output = torch.empty_like(input)
    _kernel(input, output)
    return output
```
### Register in `__init__.py`

Edit `mllm_kernel/cuda/jit/__init__.py` and add:

```python
from mllm_kernel.cuda.jit.my_kernel import my_kernel
```
### Clear the JIT cache after editing `.cuh`

Any time you modify the `.cuh` file, delete the cached `.so`:

```bash
rm -rf ~/.cache/mllm_kernel/cuda_my_kernel*
```

The next Python call will trigger recompilation automatically.
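If you clear caches often, the `rm -rf` step can be wrapped in a small Python helper. `clear_kernel_cache` is a hypothetical helper, not part of mllm_kernel. It assumes the `cuda_<name>*` naming used by the command above and, like the glob in that command, also matches any entry whose name merely starts with the same prefix (for example `cuda_my_kernel_v2`).

```python
import shutil
from pathlib import Path
from typing import List, Optional


def clear_kernel_cache(kernel_name: str, cache_dir: Optional[Path] = None) -> List[Path]:
    """Delete cached builds matching 'cuda_<kernel_name>*' and return what was removed."""
    # Default location as stated in this guide; pass cache_dir to override.
    root = cache_dir if cache_dir is not None else Path.home() / ".cache" / "mllm_kernel"
    removed: List[Path] = []
    if not root.is_dir():
        return removed
    for entry in sorted(root.glob(f"cuda_{kernel_name}*")):
        if entry.is_dir():
            shutil.rmtree(entry)  # build directories
        else:
            entry.unlink()        # stray files such as a bare .so
        removed.append(entry)
    return removed
```

Returning the removed paths makes it easy to log what was invalidated before the next call triggers a rebuild.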
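### Compile-time template arguments

When the kernel takes compile-time constants (e.g. block size, dtype), use `make_cpp_args` to splice them into the wrapper name. The launcher in the `.cuh` then has to be a class template whose parameters match, for this example something like `template <int kBlockSize, bool kUsePDL> struct MyKernel`: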
```python
from mllm_kernel.jit_utils import jit, make_cpp_args


def _make_kernel(block_size: int, use_pdl: bool):
    cpp_args = make_cpp_args(block_size, use_pdl)  # e.g. (256, True) -> "256, true"

    @jit(
        args=[block_size, use_pdl],
        device="cuda",
        cuda_files=["my_kernel.cuh"],
        cpp_wrappers=[],
        cuda_wrappers=[("my_kernel", f"MyKernel<{cpp_args}>::run")],
        func_name="my_kernel",
    )
    def _kernel(compiled_module, input, output):
        compiled_module.my_kernel(input, output)

    return _kernel
```
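To see what string ends up in `cuda_wrappers`, here is a hypothetical pure-Python stand-in for `make_cpp_args` that follows the conversion rules this guide lists for it. It is not the library's implementation: dtypes are modelled as plain strings such as `"float32"` so the sketch runs without torch, and only four dtypes are mapped.

```python
from typing import Union

# Hypothetical dtype table (the real helper maps torch.dtype objects).
_DTYPE_TO_CPP = {
    "float32": "fp32_t",
    "float16": "fp16_t",
    "bfloat16": "bf16_t",
    "int32": "int32_t",
}


def to_cpp_literal(value: Union[bool, int, float, str]) -> str:
    # bool must be tested before int: isinstance(True, int) is True in Python.
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return str(value)
    if value in _DTYPE_TO_CPP:
        return _DTYPE_TO_CPP[value]
    raise TypeError(f"unsupported template argument: {value!r}")


def make_cpp_args_sketch(*values: Union[bool, int, float, str]) -> str:
    return ", ".join(to_cpp_literal(v) for v in values)


wrapper = f"MyKernel<{make_cpp_args_sketch(256, True)}>::run"
print(wrapper)  # MyKernel<256, true>::run
```

The ordering of the `isinstance` checks is the one subtle design point: without the early `bool` branch, `True` would fall into the numeric branch and render as `True`, which is not a C++ literal.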
`make_cpp_args` converts Python types to C++ literals:

- `int` / `float` -> string literal
- `bool` -> `"true"` / `"false"`
- `torch.dtype` -> C++ type (`torch.float32` -> `"fp32_t"`, `torch.float16` -> `"fp16_t"`, `torch.bfloat16` -> `"bf16_t"`, `torch.int32` -> `"int32_t"`, etc.)

## CPU kernels (Highway)

CPU kernels use Google Highway for portable SIMD. The key difference: the `.hpp` body is included multiple times by Highway's `foreach_target` dispatch, so it must NOT have `#pragma once`.
### SIMD body (`.hpp`)

Create `mllm_kernel/cpu/include/mllm_kernel/cpu/my_kernel.hpp`:
```cpp
// NOTE: NO #pragma once -- this file is included multiple times by Highway.
#include <hwy/highway.h>

HWY_BEFORE_NAMESPACE();
namespace mllm_kernel::cpu {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

template <int Constant>
inline void my_kernel_impl(float* HWY_RESTRICT dst,
                           const float* HWY_RESTRICT src,
                           size_t count) {
  const hn::ScalableTag<float> d;
  const size_t lanes = hn::Lanes(d);
  const auto vc = hn::Set(d, static_cast<float>(Constant));

  size_t i = 0;
  // Full vectors. LoadU/StoreU accept unaligned pointers (hn::Load/hn::Store
  // require vector alignment), so tensor views at any offset are safe.
  for (; i + lanes <= count; i += lanes) {
    const auto v = hn::LoadU(d, src + i);
    hn::StoreU(hn::Add(v, vc), d, dst + i);
  }
  // Scalar tail for the remaining count % lanes elements.
  for (; i < count; ++i) {
    dst[i] = src[i] + static_cast<float>(Constant);
  }
}

// Named entry points for HWY_EXPORT
static HWY_NOINLINE HWY_MAYBE_UNUSED void my_kernel_1(float* d, const float* s, size_t n) {
  my_kernel_impl<1>(d, s, n);
}

}  // namespace HWY_NAMESPACE
}  // namespace mllm_kernel::cpu
HWY_AFTER_NAMESPACE();
```
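The vector-loop-plus-scalar-tail structure of `my_kernel_impl` is easy to check in plain Python. `add_constant_lanes` is a hypothetical model. The inner `for lane` loop stands in for one vector load/add/store step, and `lanes` stands in for `hn::Lanes(d)`, which depends on the target Highway compiled for (for float32, 4 lanes with 128-bit vectors and 8 with 256-bit).

```python
from typing import List


def add_constant_lanes(src: List[float], constant: float, lanes: int) -> List[float]:
    """Plain-Python model of my_kernel_impl's loop structure (lanes >= 1)."""
    count = len(src)
    dst = [0.0] * count
    i = 0
    # Main loop: whole vectors only, mirroring `i + lanes <= count`.
    while i + lanes <= count:
        for lane in range(lanes):  # one vector load/add/store step
            dst[i + lane] = src[i + lane] + constant
        i += lanes
    # Scalar tail: the remaining count % lanes elements, one at a time.
    while i < count:
        dst[i] = src[i] + constant
        i += 1
    return dst
```

Because the main loop stops before a partial vector, no lane ever reads past `count`. The tail loop is what makes sizes that are not a multiple of the lane count correct.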
### `.cpp` source

Create `mllm_kernel/cpu/csrc/my_kernel.cpp`:
```cpp
#include <mllm_kernel/tensor.hpp>
#include <mllm_kernel/utils.hpp>
#include <tvm/ffi/container/tensor.h>

// Highway re-includes this file once per SIMD target via HWY_TARGET_INCLUDE.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "../csrc/my_kernel.cpp"
#include <hwy/foreach_target.h>  // must precede <hwy/highway.h>, which the .hpp pulls in
#include <mllm_kernel/cpu/my_kernel.hpp>

#if HWY_ONCE
#include <hwy/targets.cc>  // provides GetChosenTarget for the JIT-built module
#endif

namespace mllm_kernel::cpu {
#if HWY_ONCE

HWY_EXPORT(my_kernel_1);

template <int Constant>
void my_kernel(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace mllm_kernel::host;

  SymbolicSize N{"num_elements"};
  SymbolicDevice device_;
  (void)TensorMatcher({N})
      .with_dtype<float>()
      .with_device<kDLCPU>(device_)
      .verify(dst)
      .verify(src);

  const size_t n = N.unwrap();
  auto* dst_ptr = static_cast<float*>(dst.data_ptr());
  const auto* src_ptr = static_cast<const float*>(src.data_ptr());
  HWY_DYNAMIC_DISPATCH(my_kernel_1)(dst_ptr, src_ptr, n);
}

// Explicit instantiation
template void my_kernel<1>(tvm::ffi::TensorView, tvm::ffi::TensorView);

#endif
}  // namespace mllm_kernel::cpu
```
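`HWY_EXPORT` and `HWY_DYNAMIC_DISPATCH` together select, at run time, one of the per-target copies that `foreach_target` compiled. The toy model below shows only that idea: on the first call, pick the best compiled target the CPU supports, then reuse the choice. `DispatchTable` and the target labels are illustrative. Real Highway keeps one process-wide chosen target (obtained through `GetChosenTarget`, which is why `hwy/targets.cc` has to be included) rather than one per table.

```python
from typing import Callable, Dict, Optional, Sequence, Set


class DispatchTable:
    """Toy model of HWY_EXPORT + HWY_DYNAMIC_DISPATCH; not Highway's implementation."""

    def __init__(self, impls: Dict[str, Callable], preference: Sequence[str],
                 cpu_supports: Set[str]) -> None:
        self._impls = impls              # one entry per target compiled by foreach_target
        self._preference = preference    # best target first
        self._cpu_supports = cpu_supports
        self.chosen: Optional[str] = None

    def __call__(self, *args):
        if self.chosen is None:
            # First call: pick the best compiled target this CPU can run, then cache it.
            self.chosen = next(
                (t for t in self._preference if t in self._impls and t in self._cpu_supports),
                None,
            )
            if self.chosen is None:
                raise RuntimeError("no compiled target is supported on this CPU")
        return self._impls[self.chosen](*args)


impls = {
    "AVX2": lambda xs: [x + 1 for x in xs],
    "SSE4": lambda xs: [x + 1 for x in xs],
    "EMU128": lambda xs: [x + 1 for x in xs],
}
my_kernel_1 = DispatchTable(impls, ["AVX3", "AVX2", "SSE4", "EMU128"], {"SSE4", "EMU128"})
print(my_kernel_1([1, 2]), my_kernel_1.chosen)  # [2, 3] SSE4
```

All entries compute the same result. Only the instruction set differs, which is why the caller never needs to know which one ran.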
### Python JIT wrapper

Create `mllm_kernel/cpu/jit/my_kernel.py`:
```python
import torch

from mllm_kernel.jit_utils import jit


@jit(
    args=1,
    device="cpu",
    cpp_files=["my_kernel.cpp"],
    cpp_wrappers=[("my_kernel", "mllm_kernel::cpu::my_kernel<1>")],
    func_name="my_kernel",
)
def _kernel_1(compiled_module, dst, src):
    # Argument order (dst, src) matches the C++ signature my_kernel(dst, src).
    compiled_module.my_kernel(dst, src)


def my_kernel(src: torch.Tensor) -> torch.Tensor:
    """Add 1 to every element of a 1-D float32 CPU tensor."""
    dst = torch.empty_like(src)
    _kernel_1(dst, src)
    return dst
```
Key CPU differences from CUDA:
| Aspect | CUDA | CPU |
|---|---|---|
| Source file | `.cuh` in `cuda/csrc/` | `.cpp` + `.hpp` in `cpu/csrc/` and `cpu/include/` |
| Namespace | Anonymous `namespace {}` | `mllm_kernel::cpu` |
| Device check | `with_device<kDLCUDA>` | `with_device<kDLCPU>` |
| Launch | `LaunchKernel(grid, block, device)(...)` | Direct function call via `HWY_DYNAMIC_DISPATCH` |
| SIMD | CUDA warps | Highway `ScalableTag<T>` |
| Wrapper fields | `cuda_files`, `cuda_wrappers` | `cpp_files`, `cpp_wrappers` |
| Wrapper name | `"MyKernel::run"` | `"mllm_kernel::cpu::my_kernel<1>"` (fully qualified) |
## TensorMatcher API

`TensorMatcher` validates shape, dtype, device, and strides of `tvm::ffi::TensorView` arguments.
```cpp
// Inside a launcher's run(); tensor_a, tensor_b, indices and mixed_tensor are TensorView arguments.
using namespace mllm_kernel::host;

// Symbolic dimensions -- bind on first .verify(), check consistency on subsequent calls
SymbolicSize B{"batch"}, N{"seq_len"}, D{"dim"};
SymbolicSize Stride0{"stride0"};
SymbolicDType dtype;
SymbolicDevice device;

// Shape [B, N, D], contiguous, float32, on CUDA
(void)TensorMatcher({B, N, D})
    .with_dtype<float>(dtype)
    .with_device<kDLCUDA>(device)
    .verify(tensor_a);

// Shape [B, N, D], same dtype and device (already bound)
(void)TensorMatcher({B, N, D})
    .with_dtype(dtype)
    .with_device(device)
    .verify(tensor_b);

// Shape [B, D] with explicit strides (non-contiguous OK)
(void)TensorMatcher({B, D})
    .with_strides({Stride0, 1})
    .with_dtype<int32_t>()
    .with_device(device)
    .verify(indices);

// Multiple acceptable dtypes
SymbolicDType flex_dtype;
(void)TensorMatcher({N})
    .with_dtype<float, __half, __nv_bfloat16>(flex_dtype)
    .with_device(device)
    .verify(mixed_tensor);

// Extract bound values
int64_t batch = B.unwrap();
int64_t dim = D.unwrap();
DLDevice dev = device.unwrap();
```
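The bind-then-check rule for symbolic sizes can be modelled in a few lines of Python. `SymSize` and `verify_shape` are hypothetical names. The real `TensorMatcher` also checks dtype, device, and strides, and its error messages differ. Plain integers in the dimension list model literal sizes, the way `1` pins the innermost stride in `.with_strides({Stride0, 1})`.

```python
from typing import Optional, Sequence, Tuple, Union


class SymSize:
    """Toy model of SymbolicSize: unbound until the first match, then fixed."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.value: Optional[int] = None

    def match(self, actual: int) -> None:
        if self.value is None:
            self.value = actual        # first verify() binds the symbol
        elif self.value != actual:     # later verify() calls must agree
            raise ValueError(f"{self.name}: expected {self.value}, got {actual}")

    def unwrap(self) -> int:
        if self.value is None:
            raise ValueError(f"{self.name} was never bound")
        return self.value


def verify_shape(dims: Sequence[Union[SymSize, int]], shape: Tuple[int, ...]) -> None:
    """Check `shape` against a mix of SymSize objects and fixed integers."""
    if len(dims) != len(shape):
        raise ValueError(f"rank mismatch: expected {len(dims)}, got {len(shape)}")
    for dim, actual in zip(dims, shape):
        if isinstance(dim, SymSize):
            dim.match(actual)
        elif dim != actual:
            raise ValueError(f"expected fixed size {dim}, got {actual}")
```

Sharing one `SymSize` object across several checks is what ties the tensors together: once `D` is bound to 64 by one tensor, any later tensor that claims a different `dim` is rejected.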
## LaunchKernel

```cpp
using namespace mllm_kernel::host;

// Basic launch (resolves CUDA stream from DLDevice)
DLDevice dev = device.unwrap();
LaunchKernel(grid_dim, block_dim, dev)(kernel_func, param_struct);

// With shared memory
LaunchKernel(grid, block, dev, shared_mem_bytes)(kernel, params);

// With PDL (Programmatic Dependent Launch, sm_90+)
LaunchKernel(grid, block, dev).enable_pdl(true)(kernel, params);
```
## Host utilities (`mllm_kernel::host`)

| Function | Description |
|---|---|
| `RuntimeCheck(cond, msg...)` | Throws `PanicError` if `cond` is false |
| `Panic(msg...)` | Always throws (use for unreachable code) |
| `div_ceil(a, b)` | Integer ceiling division |
| `dtype_bytes(DLDataType)` | Byte size of a DLPack dtype |
CUDA-only (`mllm_kernel::device`):

| Symbol | Value |
|---|---|
| `kWarpThreads` | 32 |
| `kFullMask` | `0xffffffff` |
| `fp16_t` | `__half` |
| `bf16_t` | `__nv_bfloat16` |
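The grid arithmetic used by the CUDA launcher can be checked serially in Python. `div_ceil` below is a Python equivalent of the ceiling division (not the C++ source), and `emulate_launch` is a hypothetical helper that walks every block and thread in order. The `idx >= n` guard keeps the extra threads of the last, partially filled block from writing out of bounds.

```python
from typing import List, Optional, Tuple


def div_ceil(a: int, b: int) -> int:
    """Ceiling division for a >= 0 and b > 0."""
    return (a + b - 1) // b


def emulate_launch(inputs: List[float], block: int = 256) -> Tuple[int, List[Optional[float]]]:
    """Serially emulate my_kernel: one 'thread' per index, with the bounds check."""
    n = len(inputs)
    grid = div_ceil(n, block)
    out: List[Optional[float]] = [None] * n
    for block_idx in range(grid):        # blockIdx.x
        for thread_idx in range(block):  # threadIdx.x
            idx = block_idx * block + thread_idx
            if idx >= n:                 # same guard as the CUDA kernel
                continue
            out[idx] = inputs[idx] * 2.0
    return grid, out
```

With `n = 1000` and `block = 256`, `div_ceil` gives 4 blocks (1024 threads). The last 24 threads hit the guard and do nothing.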
## Testing

Create `tests/test_my_kernel.py`:
```python
import pytest
import torch

from mllm_kernel.cuda.jit.my_kernel import my_kernel


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("n", [1, 128, 1024, 65536])
def test_my_kernel(n):
    x = torch.randn(n, dtype=torch.float32, device="cuda")
    result = my_kernel(x)
    torch.cuda.synchronize()
    expected = x * 2.0
    assert torch.allclose(result, expected)
```
Run:

```bash
pytest tests/test_my_kernel.py -v
```
## Benchmarking

Create `benchmarks/bench_my_kernel.py`. Use `torch.profiler.profile` with `ProfilerActivity.CPU` and `ProfilerActivity.CUDA`. Compare the JIT kernel against a naive PyTorch implementation and report the speedup.
Run:

```bash
python benchmarks/bench_my_kernel.py --num-elements 1000000
```
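The command-line shape and speedup arithmetic of such a benchmark can be sketched with the standard library alone. This swaps `time.perf_counter` in for `torch.profiler`, so treat it as scaffolding only: CUDA launches are asynchronous, so real GPU timings need `torch.profiler` (as recommended above) or explicit synchronization. `--iters` is a hypothetical extra flag; only `--num-elements` comes from the command above. The untimed warm-up calls keep the one-time JIT compile out of the measurement.

```python
import argparse
import time
from typing import Callable, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark my_kernel against a PyTorch reference.")
    parser.add_argument("--num-elements", type=int, default=1_000_000)
    parser.add_argument("--iters", type=int, default=20)  # hypothetical extra flag
    return parser.parse_args(argv)


def mean_seconds(fn: Callable[[], object], iters: int, warmup: int = 3) -> float:
    """Average wall-clock time per call, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # the first call may trigger the JIT compile; keep it out of the timing
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


def format_report(ref_s: float, jit_s: float) -> str:
    speedup = ref_s / jit_s
    return f"reference {ref_s * 1e6:.1f} us | jit {jit_s * 1e6:.1f} us | speedup {speedup:.2f}x"
```

In the real script, the two callables would wrap `my_kernel(x)` and the PyTorch reference `x * 2.0` for a tensor of `args.num_elements` elements.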
## Checklist

- `.cuh` / `.cpp` + `.hpp` kernel source created
- `TensorMatcher` validates all tensor arguments (shape, dtype, device)
- `@jit` wrapper created with correct `cuda_wrappers` or `cpp_wrappers`
- Public Python function that calls the JIT-decorated function (`_kernel`)
- Export added to `jit/__init__.py`
- Cache cleared after `.cuh` edits (`rm -rf ~/.cache/mllm_kernel/cuda_<name>*`)
- Tests written with `@pytest.mark.parametrize` and a PyTorch reference
- Benchmark with `torch.profiler` (optional but recommended)

## Common pitfalls

- **Reading device memory on the host**: `tensor.data_ptr()` returns a GPU pointer for CUDA tensors. Never read its contents in host code. Use `TensorMatcher` for validation instead.
- **Stale JIT cache**: after editing a `.cuh`, delete `~/.cache/mllm_kernel/cuda_<kernel_name>*/`. Otherwise the old `.so` will be reused.
- **Missing `#include <hwy/targets.cc>`**: CPU kernels must include this inside `#if HWY_ONCE` to provide `GetChosenTarget` for the JIT-built module.
- **`#pragma once` in a Highway `.hpp`**: Highway's `foreach_target` includes the file multiple times for different SIMD targets, and `#pragma once` breaks this.
- **Wrapper name format**: CUDA uses the struct method (`"MyKernel::run"`); CPU uses fully qualified names (`"mllm_kernel::cpu::my_kernel<1>"`).
- **Random generators in tests**: `torch.randperm` needs a CUDA generator on CUDA, while `torch.randint` only accepts CPU generators. Use separate generators.
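The two Highway pitfalls lend themselves to a quick text check before compiling. `check_highway_sources` is a hypothetical helper, not part of mllm_kernel, and it is a rough regex scan rather than a preprocessor. It only flags `#pragma once` when the directive starts a line, so a comment like the NOTE at the top of the `.hpp` is not flagged. It also does not understand nested `#if` blocks.

```python
import re
from typing import List


def check_highway_sources(hpp_text: str, cpp_text: str) -> List[str]:
    """Flag the two Highway pitfalls from this guide; returns a list of problems."""
    problems: List[str] = []
    # A real directive starts its line; "// ... #pragma once" comments are ignored.
    if re.search(r"^\s*#\s*pragma\s+once\b", hpp_text, flags=re.MULTILINE):
        problems.append(".hpp uses #pragma once; foreach_target must include it once per target")
    # Collect the bodies of `#if HWY_ONCE ... #endif` regions (first #endif wins).
    once_blocks = re.findall(r"#\s*if\s+HWY_ONCE(.*?)#\s*endif", cpp_text, flags=re.DOTALL)
    if not any("hwy/targets.cc" in block for block in once_blocks):
        problems.append(".cpp does not include <hwy/targets.cc> inside #if HWY_ONCE")
    return problems
```

An empty list means neither pitfall was detected; it does not prove the sources are correct.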