| name | token-embeddings |
| description | Create and work with token embeddings for LLMs. Use this skill whenever you need to understand token embeddings, create embedding layers in PyTorch, add positional embeddings (absolute, relative, or RoPE), or debug embedding-related issues in your language model. This skill covers vocabulary setup, embedding initialization, positional encoding strategies, and context window extension techniques. Make sure to use this skill when working with any LLM architecture, training pipelines, or when you need to convert tokens to numerical vectors. |
Token Embeddings Skill
This skill helps you create, understand, and work with token embeddings for large language models.
What This Skill Does
- Explains token embedding concepts and initialization
- Creates PyTorch embedding layers for your vocabulary
- Adds positional embeddings (absolute, relative, RoPE)
- Helps debug embedding-related issues
- Provides code templates for common embedding tasks
When to Use This Skill
Use this skill when you:
- Need to create token embeddings for a new vocabulary
- Want to understand how embeddings work in your model
- Need to add positional information to your embeddings
- Are debugging embedding dimension mismatches
- Want to extend context windows in RoPE-based models
- Need to implement or understand different positional encoding strategies
Core Concepts
Token Embeddings
Token embeddings convert discrete token IDs into continuous vectors. Each token ID selects one row of a learnable weight matrix of shape [vocab_size, embedding_dim], so every token in your vocabulary gets its own fixed-length vector.
Key parameters:
- vocab_size: number of unique tokens, i.e. rows in the embedding matrix (e.g., 50257 for the GPT-2 BPE tokenizer)
- embedding_dim: length of each token vector (e.g., 256, 512, 768)
Example:
Vocabulary size: 6 (token IDs 0-5)
Embedding dim: 3
Token 3 → [-0.4015, 0.9666, -1.1481] (row 3 of the randomly initialized 6 × 3 weight matrix; exact values depend on the random seed)
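The lookup can be reproduced with `torch.nn.Embedding`. This is a minimal sketch that assumes the toy vocabulary of 6 IDs and `embedding_dim=3` from the example above. The printed vectors depend on the seed. The lookup itself is always equivalent to selecting rows of the weight matrix, or to multiplying one-hot vectors by it.

```python
import torch

torch.manual_seed(123)
vocab_size, embedding_dim = 6, 3
embedding = torch.nn.Embedding(vocab_size, embedding_dim)

token_ids = torch.tensor([2, 3, 5, 1])   # valid IDs are 0 .. vocab_size - 1
vectors = embedding(token_ids)           # shape [4, 3]

# The lookup is plain row selection from the [vocab_size, embedding_dim] weight matrix...
rows = embedding.weight[token_ids]

# ...which equals multiplying one-hot vectors by that matrix (but is much cheaper).
one_hot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).float()
via_matmul = one_hot @ embedding.weight

print(vectors.shape)                         # torch.Size([4, 3])
print(torch.allclose(vectors, rows))         # True
print(torch.allclose(vectors, via_matmul))   # True
```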
Positional Embeddings
Positional embeddings encode each token's position in the sequence. Attention itself has no built-in notion of order, so without them the model sees the input largely as a "bag of words."
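The following sketch shows why positional information is needed. It is single-head self-attention with no positional signal, under these assumptions: identity query/key/value projections, no causal mask, and random inputs. Shuffling the input tokens only shuffles the outputs in the same way, so the layer cannot tell which token came first.

```python
import torch

def plain_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head attention with identity Q/K/V projections, no mask, no positions."""
    scores = x @ x.T / x.shape[-1] ** 0.5    # [seq_len, seq_len]
    weights = torch.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ x                        # [seq_len, dim]

torch.manual_seed(0)
x = torch.randn(5, 8)                  # 5 token embeddings of dimension 8
perm = torch.tensor([3, 0, 4, 1, 2])   # an arbitrary reordering of the tokens

out = plain_self_attention(x)
out_shuffled = plain_self_attention(x[perm])

# Reordering the inputs only reorders the outputs: token order carries no information.
print(torch.allclose(out_shuffled, out[perm], atol=1e-6))   # True
```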
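Types:
- Absolute: one vector per position index, either learned (GPT-2, BERT) or fixed sinusoidal (original Transformer)
- Relative: encodes the distance between tokens rather than their absolute index (e.g., Transformer-XL, T5, DeBERTa)
- RoPE: rotary embeddings that rotate query/key vectors by a position-dependent angle (e.g., LLaMA and many other recent decoder-only LLMs)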
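The classic example of fixed absolute encodings is the sinusoidal scheme from the original Transformer paper. Below is a minimal sketch. It assumes an even `embedding_dim` and the paper's base of 10000, and the function name `sinusoidal_positions` is illustrative.

```python
import math
import torch

def sinusoidal_positions(context_length: int, embedding_dim: int) -> torch.Tensor:
    """Fixed (non-learned) absolute encodings of shape [context_length, embedding_dim].

    PE[pos, 2i]   = sin(pos / 10000^(2i / embedding_dim))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / embedding_dim))
    """
    assert embedding_dim % 2 == 0, "embedding_dim must be even"
    positions = torch.arange(context_length, dtype=torch.float32).unsqueeze(1)  # [L, 1]
    # 10000^(-2i / embedding_dim), computed in log space for numerical stability
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2, dtype=torch.float32)
        * (-math.log(10000.0) / embedding_dim)
    )                                                                            # [D/2]
    pe = torch.zeros(context_length, embedding_dim)
    pe[:, 0::2] = torch.sin(positions * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(positions * div_term)   # odd dimensions
    return pe

pe = sinusoidal_positions(context_length=512, embedding_dim=256)
print(pe.shape)   # torch.Size([512, 256])
```

Nothing here is trained, so `pe` can be added to token embeddings the same way as the output of a positional `nn.Embedding`. It is typically stored with `register_buffer` inside a module so that it moves with the model across devices.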
Quick Start
Create Basic Token Embeddings
```python
import torch

vocab_size = 50257
embedding_dim = 256
token_embedding = torch.nn.Embedding(vocab_size, embedding_dim)
```
Add Absolute Positional Embeddings
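```python
context_length = 512
pos_embedding = torch.nn.Embedding(context_length, embedding_dim)

token_ids = torch.randint(0, vocab_size, (8, 4))   # example input: [batch_size, seq_len]
seq_len = token_ids.shape[1]
token_emb = token_embedding(token_ids)             # [8, 4, 256]
pos_emb = pos_embedding(torch.arange(seq_len, device=token_ids.device))  # [4, 256]
combined = token_emb + pos_emb                     # broadcasts over the batch: [8, 4, 256]
```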
RoPE (Rotary Positional Embeddings)
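Many recent decoder-only LLMs (e.g., LLaMA) use RoPE instead. Rather than adding a vector to the token embeddings, RoPE rotates the query and key vectors inside attention by a position-dependent angle:
```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in half and map (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    """Apply rotary positional embeddings to query/key vectors.

    cos and sin hold per-position rotation angles and must broadcast against q and k.
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```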
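`apply_rope` expects precomputed `cos`/`sin` tables. Below is a minimal sketch of building them. It assumes the "rotate-half" layout (dimension `i` is paired with `i + head_dim/2`) and the commonly used base of 10000. `build_rope_cache` is a hypothetical helper name. The checks at the end show the property that makes RoPE attractive: a query-key score depends only on the offset between the two positions, and the rotation does not change vector lengths.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    """Hypothetical helper: cos/sin tables of shape [seq_len, head_dim]."""
    # One frequency per rotated pair: base^(-2i / head_dim) for i = 0 .. head_dim/2 - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float64)
    freqs = torch.outer(positions, inv_freq)   # [seq_len, head_dim // 2], angle = pos * freq
    emb = torch.cat((freqs, freqs), dim=-1)    # dims i and i + head_dim/2 share an angle
    return emb.cos(), emb.sin()

seq_len, head_dim = 16, 8
cos, sin = build_rope_cache(seq_len, head_dim)

# Use the same query and key vector at every position to isolate the effect of position.
torch.manual_seed(0)
q = torch.randn(head_dim, dtype=torch.float64).expand(seq_len, head_dim)
k = torch.randn(head_dim, dtype=torch.float64).expand(seq_len, head_dim)
q_rot, k_rot = apply_rope(q, k, cos, sin)

scores = q_rot @ k_rot.T   # scores[m, n]: query at position m vs key at position n

print(torch.allclose(scores[2, 5], scores[7, 10]))          # True: both pairs are 3 apart
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1)))   # True: rotations preserve length
```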
Common Tasks
Task 1: Initialize Embedding Layer
```python
import torch

def create_token_embeddings(vocab_size: int, embedding_dim: int) -> torch.nn.Embedding:
    """Create a token embedding layer."""
    return torch.nn.Embedding(vocab_size, embedding_dim)

embedding_layer = create_token_embeddings(50257, 256)
print(embedding_layer.weight.shape)  # torch.Size([50257, 256])
```
Task 2: Create Positional Embeddings
```python
import torch

def create_positional_embeddings(context_length: int, embedding_dim: int) -> torch.nn.Embedding:
    """Create absolute (learned) positional embeddings."""
    return torch.nn.Embedding(context_length, embedding_dim)

pos_layer = create_positional_embeddings(512, 256)
pos_embeddings = pos_layer(torch.arange(512))  # one vector per position 0..511
print(pos_embeddings.shape)  # torch.Size([512, 256])
```
Task 3: Combine Token and Positional Embeddings
```python
import torch

def combine_embeddings(
    token_ids: torch.Tensor,
    token_embedding: torch.nn.Embedding,
    pos_embedding: torch.nn.Embedding
) -> torch.Tensor:
    """Combine token and positional embeddings.

    Args:
        token_ids: [batch_size, seq_len]
        token_embedding: Token embedding layer
        pos_embedding: Positional embedding layer
    Returns:
        Combined embeddings: [batch_size, seq_len, embedding_dim]
    """
    batch_size, seq_len = token_ids.shape
    token_emb = token_embedding(token_ids)
    # Build positions on the same device as the input so this also works on GPU.
    positions = torch.arange(seq_len, device=token_ids.device).expand(batch_size, -1)
    pos_emb = pos_embedding(positions)
    return token_emb + pos_emb
```
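The `expand` call is not strictly needed. A `[seq_len, embedding_dim]` positional tensor broadcasts across the batch dimension when added to a `[batch_size, seq_len, embedding_dim]` tensor. Here is a small check with illustrative sizes: batch 2, 4 tokens, a toy vocabulary of 100 and `embedding_dim=16`.

```python
import torch

torch.manual_seed(0)
token_embedding = torch.nn.Embedding(100, 16)
pos_embedding = torch.nn.Embedding(32, 16)
token_ids = torch.randint(0, 100, (2, 4))   # [batch_size=2, seq_len=4]
batch_size, seq_len = token_ids.shape

positions = torch.arange(seq_len, device=token_ids.device)

# Version 1: explicit [batch_size, seq_len] position indices.
expanded = token_embedding(token_ids) + pos_embedding(positions.expand(batch_size, -1))

# Version 2: [seq_len, embedding_dim] positional vectors broadcast over the batch.
broadcast = token_embedding(token_ids) + pos_embedding(positions)

print(broadcast.shape)                       # torch.Size([2, 4, 16])
print(torch.allclose(expanded, broadcast))   # True
```

The broadcast version looks up each position once instead of once per batch row. Both versions produce identical results.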
Task 4: Position Interpolation for Extended Context
```python
import torch

def position_interpolation(
    pos_ids: torch.Tensor,
    original_context: int,
    new_context: int
) -> torch.Tensor:
    """Scale position indices for context window extension in RoPE-based models.

    Args:
        pos_ids: Original position indices
        original_context: Training context length (e.g., 2048)
        new_context: Target context length (e.g., 8192)
    Returns:
        Scaled, fractional position indices in [0, original_context)
    """
    scale = original_context / new_context
    # Keep positions fractional: casting to long would map 4 neighbouring tokens
    # (for 2048 -> 8192) onto the same position and discard their relative order.
    return pos_ids.float() * scale

original_ctx = 2048
new_ctx = 8192
positions = torch.arange(new_ctx)
scaled_positions = position_interpolation(positions, original_ctx, new_ctx)
```
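RoPE positions enter only through rotation angles, so interpolated positions can stay fractional rather than being rounded to integers. This minimal sketch assumes a RoPE head dimension of 64 and base 10000. `rope_angles` is a hypothetical helper. The interpolated positions for 8192 tokens stay below 2048, the limit of the range seen in training. Rounding them to integers would give groups of four neighbouring tokens the same position.

```python
import torch

def rope_angles(positions: torch.Tensor, head_dim: int = 64, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: RoPE angles [len(positions), head_dim] (rotate-half layout)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(positions.float(), inv_freq)   # angle = position * frequency
    return torch.cat((freqs, freqs), dim=-1)

original_ctx, new_ctx = 2048, 8192
positions = torch.arange(new_ctx)
scaled = positions * (original_ctx / new_ctx)   # 0.0, 0.25, 0.5, ... (kept fractional)

angles = rope_angles(scaled)
cos, sin = angles.cos(), angles.sin()           # [8192, 64] tables for the RoPE rotation

print(scaled.max().item())                      # 2047.75: still below 2048
print(scaled.long().unique().numel())           # 2048: rounding merges tokens into groups of 4
```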
Debugging Checklist
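When embeddings aren't working correctly, check:

- Dimension mismatches: token and positional embeddings must share the last (embedding) dimension. Their leading dimensions may differ and broadcast.
```python
assert token_emb.shape[-1] == pos_emb.shape[-1], "Embedding dimensions must match"
```
- Vocabulary size: every token ID must lie in [0, vocab_size).
```python
min_id, max_id = token_ids.min().item(), token_ids.max().item()
assert 0 <= min_id and max_id < vocab_size, f"Token IDs [{min_id}, {max_id}] fall outside [0, {vocab_size})"
```
- Context length: learned absolute positional embeddings have no row beyond context_length - 1.
```python
seq_len = token_ids.shape[1]
assert seq_len <= context_length, f"Sequence {seq_len} exceeds context {context_length}"
```
- Gradient flow: embeddings should be trainable unless you froze them deliberately.
```python
assert token_embedding.weight.requires_grad, "Embeddings should be trainable"
```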
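The checklist can be bundled into one function that runs before the forward pass. This is a sketch only: `check_embedding_inputs` is a hypothetical name, and it reads sizes from the standard `nn.Embedding` attributes `num_embeddings` and `embedding_dim`. It raises `ValueError` instead of using `assert` because assertions are skipped when Python runs with `-O`.

```python
import torch

def check_embedding_inputs(
    token_ids: torch.Tensor,
    token_embedding: torch.nn.Embedding,
    pos_embedding: torch.nn.Embedding,
) -> None:
    """Hypothetical helper: validate inputs against the embedding layers."""
    if token_ids.dim() != 2:
        raise ValueError(f"expected [batch_size, seq_len], got shape {tuple(token_ids.shape)}")
    if token_ids.dtype not in (torch.int32, torch.int64):
        raise ValueError(f"token_ids must be an integer tensor, got {token_ids.dtype}")
    vocab_size = token_embedding.num_embeddings
    if token_ids.min().item() < 0 or token_ids.max().item() >= vocab_size:
        raise ValueError(f"token IDs must lie in [0, {vocab_size - 1}]")
    seq_len = token_ids.shape[1]
    if seq_len > pos_embedding.num_embeddings:
        raise ValueError(
            f"sequence length {seq_len} exceeds context length {pos_embedding.num_embeddings}"
        )
    if token_embedding.embedding_dim != pos_embedding.embedding_dim:
        raise ValueError("token and positional embedding dimensions differ")

tok = torch.nn.Embedding(100, 16)
pos = torch.nn.Embedding(8, 16)
check_embedding_inputs(torch.randint(0, 100, (2, 8)), tok, pos)   # passes silently
try:
    check_embedding_inputs(torch.randint(0, 100, (2, 9)), tok, pos)
except ValueError as err:
    print(err)   # sequence length 9 exceeds context length 8
```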
Best Practices
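- Embedding dimensions: Dimensions divisible by 64 (e.g., 256, 512, 768, 1024) generally map well onto GPU matrix-multiply hardware.
- Initialization: nn.Embedding initializes its weights from a standard normal N(0, 1) by default, not with Xavier initialization. GPT-2-style models re-initialize embeddings with a small standard deviation such as 0.02.
- Positional encoding: RoPE is the common choice in recent decoder-only models. Classic encoder-only models such as BERT use learned absolute embeddings.
- Context extension: Apply position interpolation to RoPE positions, then fine-tune briefly on the longer sequences.
- Batch processing: Pass [batch_size, seq_len] tensors so the embedding lookup and positional addition run vectorized over the whole batch.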
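To see PyTorch's default initialization and switch to a GPT-2-style one, compare standard deviations before and after re-initializing. This is a minimal sketch. The 0.02 value is a recipe choice borrowed from GPT-2-style training, not a PyTorch default.

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(50257, 256)
default_std = embedding.weight.std().item()
print(f"default std: {default_std:.3f}")        # close to 1.0: the default is N(0, 1)

# GPT-2-style small-variance re-initialization (a recipe choice, not a PyTorch default).
torch.nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
new_std = embedding.weight.std().item()
print(f"re-initialized std: {new_std:.3f}")     # close to 0.02
```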
Example: Complete Embedding Setup
```python
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, context_length: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(context_length, embedding_dim)
        self.context_length = context_length

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        _, seq_len = token_ids.shape
        if seq_len > self.context_length:
            raise ValueError(f"Sequence {seq_len} exceeds context {self.context_length}")
        token_emb = self.token_embedding(token_ids)       # [batch, seq, dim]
        # Create positions on the same device as the input (CPU or GPU).
        positions = torch.arange(seq_len, device=token_ids.device)
        pos_emb = self.pos_embedding(positions)           # [seq, dim], broadcasts over batch
        return token_emb + pos_emb

vocab_size = 50257
embedding_dim = 256
context_length = 512
embedding_model = TokenEmbedding(vocab_size, embedding_dim, context_length)

batch_size = 8
seq_len = 4
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
output = embedding_model(token_ids)
print(output.shape)  # torch.Size([8, 4, 256])
```
References