| name | docstring |
| description | Write docstrings for TorchRec functions and methods following PyTorch conventions. Use when writing or updating docstrings in TorchRec code. |
TorchRec Docstring Writing Guide
This skill describes how to write docstrings for functions and methods in the TorchRec project, following PyTorch conventions.
General Principles
- Use raw strings (
r"""...""") for all docstrings to avoid issues with LaTeX/math backslashes
- Follow Sphinx/reStructuredText (reST) format for documentation
- Be concise but complete - include all essential information
- Always include examples when possible
- Use cross-references to related functions/classes
Docstring Structure
1. Function Signature (First Line)
Start with the function signature showing all parameters:
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
Notes:
- Include the function name
- Show positional and keyword-only arguments (use
* separator)
- Include default values
- Show return type annotation
- This line should NOT end with a period
2. Brief Description
Provide a one-line description of what the function does:
r"""apply_optimizer_in_backward(optimizer_class, params, optimizer_kwargs) -> None
Applies optimizer to parameters in backward pass for memory efficiency.
3. Mathematical Formulas (if applicable)
Use Sphinx math directives for mathematical expressions:
.. math::
\text{output} = \text{input} \cdot \text{weight}^T
Or inline math: :math:\x^2``
4. Cross-References
Link to related classes and functions using Sphinx roles:
:class:\~torchrec.modules.EmbeddingBagCollection`` - Link to a class
:func:\torchrec.distributed.sharding.shard`` - Link to a function
:meth:\~Module.forward`` - Link to a method
:attr:\attribute_name`` - Reference an attribute
- The
~ prefix shows only the last component
Example:
See :class:`~torchrec.distributed.DistributedModelParallel` for details.
5. Notes and Warnings
Use admonitions for important information:
.. note::
This function requires CUDA to be available.
.. warning::
This API is experimental and may change without notice.
6. Args Section
Document all parameters with type annotations and descriptions:
Args:
module (nn.Module): Module to be sharded across devices.
device (torch.device, optional): Device to place the module. Default: ``None``
sharders (List[ModuleSharder], optional): List of sharders to use for sharding.
Default: ``None``
plan (ShardingPlan, optional): Explicit sharding plan. If not provided, will be
generated automatically. Default: ``None``
Formatting rules:
- Parameter name in lowercase
- Type in parentheses:
(Type), (Type, optional) for optional parameters
- Description follows the type
- For optional parameters, include "Default:
value" at the end
- Use double backticks for inline code:
``None``
- Indent continuation lines by 4 spaces
7. Returns Section
Document the return value:
Returns:
ShardedModule: The sharded module ready for distributed training.
The module will have its parameters distributed according to
the sharding plan.
8. Raises Section (if applicable)
Document exceptions that may be raised:
Raises:
ValueError: If the sharding plan is invalid for the given module.
RuntimeError: If CUDA is not available when GPU sharding is requested.
9. Examples Section
Always include examples when possible:
Examples::
>>> import torchrec
>>> from torchrec.modules import EmbeddingBagCollection
>>> ebc = EmbeddingBagCollection(
... tables=[
... EmbeddingBagConfig(
... name="product",
... embedding_dim=64,
... num_embeddings=1000,
... feature_names=["product_id"],
... ),
... ],
... )
>>>
>>> sharded_ebc = shard(ebc, plan=plan)
Formatting rules:
- Use
Examples:: with double colon
- Use
>>> prompt for Python code
- Include comments with
# when helpful
- Show actual output when it helps understanding
TorchRec-Specific Patterns
Embedding Configuration
Args:
tables (List[EmbeddingBagConfig]): List of embedding table configurations.
Each config specifies the table name, embedding dimension, number of
embeddings, and feature names.
device (Optional[torch.device]): Device to place embeddings. Default: ``None``
Sharding-Related
Args:
sharding_type (ShardingType): How to shard the embedding table. Options are:
- ``TABLE_WISE``: Each table on a single device
- ``ROW_WISE``: Rows distributed across devices
- ``COLUMN_WISE``: Columns distributed across devices
- ``TABLE_ROW_WISE``: Combination of table and row sharding
KeyedJaggedTensor
Args:
kjt (KeyedJaggedTensor): Sparse features in KeyedJaggedTensor format.
Contains keys (feature names), values (embedding indices), and
lengths/offsets for variable-length sequences.
Complete Example
def shard_modules(
module: nn.Module,
plan: ShardingPlan,
env: ShardingEnv,
device: Optional[torch.device] = None,
) -> nn.Module:
r"""
Shard a module's embedding tables according to a sharding plan.
This function takes a module containing embedding tables and distributes
them across multiple devices according to the provided sharding plan.
It supports various sharding strategies including table-wise, row-wise,
and column-wise sharding.
Args:
module (nn.Module): The module containing embedding tables to shard.
plan (ShardingPlan): The sharding plan specifying how each table
should be distributed.
env (ShardingEnv): The sharding environment containing process group
information and device topology.
device (torch.device, optional): Target device for local shards.
Default: ``None`` (uses current device)
Returns:
nn.Module: The sharded module with distributed embedding tables.
Raises:
ValueError: If the plan references tables not present in the module.
RuntimeError: If the sharding environment is not properly initialized.
.. note::
This function modifies the module in-place for efficiency.
.. warning::
This is an experimental API and may change in future releases.
Examples::
>>> from torchrec.distributed import shard_modules
>>> from torchrec.distributed.planner import EmbeddingShardingPlanner
>>>
>>> # Create a sharding plan
>>> planner = EmbeddingShardingPlanner()
>>> plan = planner.plan(module, sharders)
>>>
>>> # Shard the module
>>> sharded_module = shard_modules(module, plan, env)
"""
Quick Checklist
When writing a TorchRec docstring, ensure: