test-gen
Generate tests for TorchRec source files with correct patterns (unit, distributed, hypothesis), proper BUCK targets, and test utilities. Use when asked to generate tests, add test coverage, or write tests for a module.
| Field | Value |
|---|---|
| name | test-gen |
| argument-hint | ["file-path or \"local\""] |
| description | Generate tests for TorchRec source files with correct patterns (unit, distributed, hypothesis), proper BUCK targets, and test utilities. Use when asked to generate tests, add test coverage, or write tests for a module. |
| allowed-tools | Read, Write, Edit, Bash(sl:*), Bash(buck2:*), Grep, Glob, Task |
Generate idiomatic TorchRec tests by reading source files, detecting the appropriate test type, scaffolding test code with correct patterns, and creating/updating BUCK targets.
```
/test-gen torchrec/distributed/sharding/my_sharder.py
/test-gen torchrec/modules/new_module.py
```

Generate tests for the specified source file.

```
/test-gen local
/test-gen
```

Detect changed files via `sl status` and generate tests for new or modified source files that lack test coverage.
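Local mode's file selection can be sketched as a pure function over `sl status` output. This is a minimal sketch, not part of the skill: `select_source_files` is a hypothetical helper. It assumes the short `<code> <path>` line format with paths relative to the directory that contains `torchrec/`. Treating untracked (`?`) files as new is also an assumption.

```python
# Hypothetical helper: `sl status` lines look like "M path/to/file.py".
# M = modified, A = added; "?" (untracked) is treated as new by assumption.
CANDIDATE_CODES = {"M", "A", "?"}


def select_source_files(sl_status_output: str) -> list[str]:
    """Return changed torchrec/ .py source files that may need tests."""
    selected: list[str] = []
    for line in sl_status_output.splitlines():
        parts = line.split(maxsplit=1)
        if len(parts) != 2 or parts[0] not in CANDIDATE_CODES:
            continue
        path = parts[1].strip()
        name = path.rsplit("/", 1)[-1]
        # Only Python sources under torchrec/; this also drops BUCK files.
        if not path.startswith("torchrec/") or not path.endswith(".py"):
            continue
        # Exclude package markers and anything that is already a test file.
        if name == "__init__.py" or name.startswith("test_"):
            continue
        if "/tests/" in path or "/test/" in path:
            continue
        selected.append(path)
    return selected
```

Removed files (`R`) and paths outside `torchrec/` are ignored. The surviving paths are the ones to read and classify.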
File path mode: Read the specified file.

Auto-detect mode:
- Run `sl status` to find changed files.
- Keep `.py` source files in `torchrec/` (exclude test files, `__init__.py`, and BUCK files).
- For each file, look for an existing test at `$(dirname)/test(s)/test_$(basename)`.

Read the source file and classify it:
Detection rules (in priority order):
Distributed test if ANY of:
- Source is under `torchrec/distributed/`
- Imports `torch.distributed` or `torchrec.distributed`, or uses `ProcessGroup`
- References `ShardingType`
- Uses `LazyAwaitable`, `all_to_all`, `all_reduce`, or `all_gather`

Hypothesis-parameterized test if ANY of:
- Behavior varies across multiple `ShardingType` or `EmbeddingComputeKernel` values

Unit test (default) if:
- Source is under `torchrec/modules/`, `torchrec/sparse/`, `torchrec/optim/`, or `torchrec/metrics/`

A file can be both distributed AND hypothesis-parameterized.
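The detection rules above can be approximated with a text scan. This is a hedged sketch: `classify` and the marker tuples are hypothetical names. Matching substrings instead of resolving imports can misfire on comments or docstrings. The hypothesis rule really asks whether behavior varies across `ShardingType` or `EmbeddingComputeKernel` values, and a substring check only approximates that. Treat the result as a starting point for reading the file, not a verdict.

```python
# Markers taken from the distributed and hypothesis rules; plain substring
# matching is a heuristic, not import resolution.
DISTRIBUTED_MARKERS = (
    "torch.distributed",
    "torchrec.distributed",
    "ProcessGroup",
    "ShardingType",
    "LazyAwaitable",
    "all_to_all",
    "all_reduce",
    "all_gather",
)
HYPOTHESIS_MARKERS = ("ShardingType", "EmbeddingComputeKernel")


def classify(path: str, source: str) -> set[str]:
    """Return the test types to generate for one source file."""
    kinds: set[str] = set()
    if path.startswith("torchrec/distributed/") or any(
        marker in source for marker in DISTRIBUTED_MARKERS
    ):
        kinds.add("distributed")
    if any(marker in source for marker in HYPOTHESIS_MARKERS):
        kinds.add("hypothesis")
    if "distributed" not in kinds:
        # Unit test is the default when no distributed rule matches.
        kinds.add("unit")
    return kinds
```

A sharder that imports `ShardingType` comes back as both `distributed` and `hypothesis`, matching the note that the two can combine.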
Extract from the source file:
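One way to start the extraction is to list the public surface with the standard `ast` module. This is a minimal sketch that assumes the top-level classes and functions are what need tests. `public_api` is a hypothetical helper. It skips every `_`-prefixed name except `__init__`, even though the guidelines allow testing private helpers whose logic is complex and critical.

```python
import ast


def public_api(source: str) -> dict[str, list[str]]:
    """Map each public top-level class/function to its public methods."""
    tree = ast.parse(source)
    api: dict[str, list[str]] = {}
    for node in tree.body:
        is_func = isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        if is_func and not node.name.startswith("_"):
            api[node.name] = []
        elif isinstance(node, ast.ClassDef) and not node.name.startswith("_"):
            # Keep the constructor (it defines the config surface) plus public methods.
            api[node.name] = [
                item.name
                for item in node.body
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                and (item.name == "__init__" or not item.name.startswith("_"))
            ]
    return api
```

For an `nn.Module`, the returned methods usually include `forward`, which is the main thing unit tests exercise.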
Follow TorchRec convention:
- Source: `torchrec/foo/bar/my_module.py`
- Test: `torchrec/foo/bar/tests/test_my_module.py`

If a `tests/` directory doesn't exist, create it.
If a test file already exists, add new test methods rather than overwriting.
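The path convention can be expressed as a small mapping plus a lookup for an existing file to extend. This sketch uses hypothetical helper names. It checks a set of known paths so it stays pure; in practice that set would come from the filesystem, for example via `os.path.exists`. Checking both `tests/` and `test/` mirrors the `test(s)` pattern used when auto-detecting.

```python
from pathlib import PurePosixPath
from typing import Optional


def expected_test_path(source_path: str) -> str:
    """torchrec/foo/bar/my_module.py -> torchrec/foo/bar/tests/test_my_module.py"""
    src = PurePosixPath(source_path)
    return str(src.parent / "tests" / f"test_{src.name}")


def find_existing_test(source_path: str, known_paths: set[str]) -> Optional[str]:
    """Return an existing test file to extend, checking tests/ before test/."""
    src = PurePosixPath(source_path)
    for dirname in ("tests", "test"):
        candidate = str(src.parent / dirname / f"test_{src.name}")
        if candidate in known_paths:
            return candidate
    return None
```

If `find_existing_test` returns a path, new test methods go into that file. Otherwise the file is created at `expected_test_path`, creating `tests/` first.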
Generate tests following the patterns below. See test-patterns.md for complete templates.
For all test types:
- Start the file with `# pyre-strict`
- Type-annotate all methods (`-> None` for test methods)
- Use `self.assertEqual`, `self.assertTrue`, and `torch.testing.assert_close` for assertions
- Name tests `test_<what>_<condition>`

For unit tests:
- Inherit from `unittest.TestCase`
- Call `forward()` with representative inputs and verify output shapes and types

For distributed tests:
- Inherit from `MultiProcessTestBase`
- Define the per-rank body with the `@staticmethod` or module-level `_test_func(rank, world_size, **kwargs)` pattern
- Set up each rank with `with MultiProcessContext(rank, world_size, backend) as ctx:`
- Default to `world_size=2`; add `world_size=4` for sharding tests
- Use `backend="gloo"` unless testing GPU-specific behavior
- Add `@unittest.skipIf(torch.cuda.device_count() < N, "Not enough GPUs...")` for CUDA tests

For hypothesis tests:
- Decorate with `@given(...)`, using `st.sampled_from([...])` for enum/config parameters
- Add `@settings(verbosity=Verbosity.verbose, max_examples=N, deadline=None)`
- Use `assume()` to filter invalid parameter combinations
- Keep `max_examples` reasonable (4-8 for distributed tests, 10-20 for unit tests)

Read the existing BUCK file in the `tests/` directory (or create one if it doesn't exist).
For CPU-only unit tests:
```python
python_unittest(
    name = "test_my_module",
    srcs = ["test_my_module.py"],
    deps = [
        "//caffe2:_torch",
        # ... source deps ...
    ],
)
```
For GPU/distributed tests:
```python
python_unittest(
    name = "test_my_module",
    srcs = ["test_my_module.py"],
    remote_execution = re_test_utils.remote_execution(
        mig = "false",
        platform = "gpu-remote-execution",
        resource_units = 2,
    ),
    deps = [
        "//caffe2:_torch",
        "//torchrec/distributed/test_utils:multi_process",
        # ... source deps ...
    ],
)
```
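The CPU and GPU/distributed templates above, plus the hypothesis additions described next, can be combined into one renderer. This is only a sketch: `render_buck_target` is a hypothetical helper, and its `distributed` flag stands for the GPU/distributed template. The caller still supplies the source deps found by reading the source directory's BUCK file.

```python
def render_buck_target(
    name: str,
    deps: list[str],
    distributed: bool = False,
    hypothesis: bool = False,
) -> str:
    """Render a python_unittest target for a tests/ BUCK file."""
    all_deps = ["//caffe2:_torch"]
    if distributed:
        all_deps.append("//torchrec/distributed/test_utils:multi_process")
    all_deps += deps
    if hypothesis:
        all_deps.append("fbsource//third-party/pypi/hypothesis:hypothesis")
    lines = [
        "python_unittest(",
        f'    name = "{name}",',
        f'    srcs = ["{name}.py"],',
    ]
    if distributed:
        # Same remote-execution block as the GPU/distributed template.
        lines += [
            "    remote_execution = re_test_utils.remote_execution(",
            '        mig = "false",',
            '        platform = "gpu-remote-execution",',
            "        resource_units = 2,",
            "    ),",
        ]
    if hypothesis:
        # The skill requires this whenever hypothesis is used.
        lines.append("    supports_static_listing = False,")
    lines.append("    deps = [")
    # dict.fromkeys drops duplicate deps while keeping their order.
    lines += [f'        "{dep}",' for dep in dict.fromkeys(all_deps)]
    lines += ["    ],", ")"]
    return "\n".join(lines) + "\n"
```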
If hypothesis is used, add:

```python
supports_static_listing = False,
```

and add to deps:

```python
"fbsource//third-party/pypi/hypothesis:hypothesis",
```
BUCK rules:
- `load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")` for standard tests
- `load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")` for GPU tests
- Keep `oncall("torchrec")` if it is already present in the BUCK file
- Map each `torchrec.*` import to its BUCK target by checking the source directory's BUCK file

Run the generated tests:

```bash
buck2 test fbcode//torchrec/path/to/tests:test_my_module
buck2 test fbcode//torchrec/path/to/tests:test_my_module -- -s
```
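For the import-to-target rule above, the imported `torchrec.*` modules can be collected with `ast` and mapped to the directory whose BUCK file should declare them. This is a hedged sketch. `torchrec_import_dirs` is hypothetical, and it assumes `from torchrec.x.y import Name` refers to a module file `torchrec/x/y.py`. Package-level imports such as `from torchrec.distributed import types` need manual handling. It deliberately does not guess target names, which still have to be read from that directory's BUCK file.

```python
import ast


def torchrec_import_dirs(source: str) -> dict[str, str]:
    """Map each imported torchrec.* module to the directory owning its BUCK target."""
    modules: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules.update(a.name for a in node.names if a.name.startswith("torchrec."))
        elif isinstance(node, ast.ImportFrom) and node.module:
            # Relative imports have level > 0 and no "torchrec." prefix, so they are skipped.
            if node.module.startswith("torchrec."):
                modules.add(node.module)
    # torchrec.modules.embedding_modules -> torchrec/modules
    return {m: "/".join(m.split(".")[:-1]) for m in sorted(modules)}
```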
Use these utilities when generating tests:
| Utility | Import | When to Use |
|---|---|---|
| `MultiProcessTestBase` | `torchrec.distributed.test_utils.multi_process` | All distributed tests |
| `MultiProcessContext` | `torchrec.distributed.test_utils.multi_process` | Per-rank setup/teardown |
| `ModelInput` | `torchrec.distributed.test_utils.test_model` | Generating test inputs for models |
| `TestSparseNN` | `torchrec.distributed.test_utils.test_model` | Test model with embedding tables |
| `sharding_single_rank_test` | `torchrec.distributed.test_utils.test_sharding` | Testing sharders |
| `create_test_sharder` | `torchrec.distributed.test_utils.test_sharding` | Creating test sharder instances |
| `skip_if_asan_class` | `torchrec.test_utils` | Skip entire class under ASAN |
| `seed_and_log` | `torchrec.test_utils` | Deterministic seeding with logging |
| `get_free_port` | `torchrec.test_utils` | Getting available port for dist init |
- Skip private methods (prefixed with `_`) unless they contain complex logic that's critical to test.
- Keep type hints consistent with the source file (`list[str]` vs `List[str]`).
- Prefer real implementations over mocks: `MultiProcessContext` + a real gloo PG for distributed tests, `ScopedConfigeratorFake` / JK overrides for config and feature flags, and in-memory fakes where they exist. Reach for `mock.patch` / `MagicMock` only when no real fake exists for the dependency, and call out why in a one-line comment.

$ARGUMENTS