| name | segment-anything-model |
| description | SAM: zero-shot image segmentation via points, boxes, masks. |
| version | 1.0.0 |
| author | Orchestra Research |
| license | MIT |
| dependencies | ["segment-anything","transformers>=4.30.0","torch>=1.7.0"] |
| platforms | ["linux","macos","windows"] |
| metadata | {"hermes":{"tags":["Multimodal","Image Segmentation","Computer Vision","SAM","Zero-Shot"]}} |
Segment Anything Model (SAM)
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
When to use SAM
Use SAM when:
- You need to segment arbitrary objects without task-specific training
- You are building interactive annotation tools with point/box prompts
- You are generating training data (masks) for other vision models
- You need zero-shot transfer to new image domains
- You are building object detection/segmentation pipelines
- You are processing medical, satellite, or other domain-specific images
Key features:
- Zero-shot segmentation: Works on any image domain without fine-tuning
- Flexible prompts: Points, bounding boxes, or previous masks
- Automatic segmentation: Generate all object masks automatically
- High quality: Trained on 1.1 billion masks from 11 million images
- Multiple model sizes: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- ONNX export: Deploy in browsers and edge devices
Use alternatives instead:
- YOLO/Detectron2: For real-time object detection with classes
- Mask2Former: For semantic/panoptic segmentation with categories
- GroundingDINO + SAM: For text-prompted segmentation
- SAM 2: For video segmentation tasks
Quick start
Installation
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install opencv-python pycocotools matplotlib
pip install transformers
Download checkpoints
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Basic usage with SamPredictor
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# Load a checkpoint and move it to the GPU
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)
# OpenCV loads BGR; SAM expects RGB
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)  # computes the image embedding once
# One foreground click at (x, y) = (500, 375)
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True  # return 3 candidate masks with quality scores
)
best_mask = masks[np.argmax(scores)]  # keep the highest-scoring candidate
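To sanity-check the result, a quick matplotlib overlay is enough; a minimal sketch assuming `image`, `best_mask`, and `input_point` from the snippet above:
import matplotlib.pyplot as plt
plt.imshow(image)
plt.imshow(best_mask, alpha=0.5)  # translucent mask over the photo
plt.scatter(*input_point[0], c="lime", marker="*", s=200)  # the clicked point
plt.axis("off")
plt.show()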
HuggingFace Transformers
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")
image = Image.open("image.jpg").convert("RGB")
input_points = [[[450, 600]]]  # one (x, y) point for the single input image
inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
# Upscale the low-res predicted masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)
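The model also returns per-mask IoU estimates in `outputs.iou_scores`, which can be used to pick the best candidate; a sketch whose indexing assumes a single image and a single point prompt:
scores = outputs.iou_scores.cpu()  # shape: (images, point sets, masks per prompt)
best = scores[0, 0].argmax().item()
best_mask = masks[0][0, best]      # boolean (H, W) tensor for the top-scoring mask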
Core concepts
Model architecture
SAM Architecture:
┌─────────────────┐
│  Image Encoder  │── image embeddings ──┐
│      (ViT)      │   (computed once)    │
└─────────────────┘                      ▼
┌─────────────────┐               ┌─────────────────┐
│ Prompt Encoder  │──────────────▶│  Mask Decoder   │──▶ Masks + IoU
│ (Points/Boxes)  │   prompt      │  (Transformer)  │    predictions
└─────────────────┘   embeddings
                      (per prompt)
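Because the ViT encoder runs once per image while each prompt touches only the small decoder, many prompts against one image are cheap. A sketch of the amortized pattern, reusing `predictor` from the quick start (the click positions are made-up examples):
predictor.set_image(image)  # expensive: one ViT forward pass, cached as embeddings
for x, y in [(100, 100), (250, 180), (400, 320)]:
    masks, scores, _ = predictor.predict(  # cheap: prompt encoder + decoder only
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )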
Model variants
| Model | Checkpoint | Size | Speed | Accuracy |
|---|---|---|---|---|
| ViT-H | vit_h | 2.4 GB | Slowest | Best |
| ViT-L | vit_l | 1.2 GB | Medium | Good |
| ViT-B | vit_b | 375 MB | Fastest | Good |
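The registry keys above map straight to constructors, so switching variants can be a one-line change; a small hypothetical helper (`load_sam`; checkpoint file names from the download step):
from segment_anything import sam_model_registry
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
def load_sam(model_type="vit_b", device="cuda"):
    # Trade accuracy (vit_h) for speed and memory (vit_b) with one argument
    sam = sam_model_registry[model_type](checkpoint=CHECKPOINTS[model_type])
    return sam.to(device=device)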
Prompt types
| Prompt | Description | Use Case |
|---|---|---|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |
Interactive segmentation
Point prompts
# Single foreground point; request 3 candidate masks
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True
)
# Multiple points: label 1 = foreground, 0 = background
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0])
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False  # with several points, a single mask is usually best
)
Box prompts
input_box = np.array([425, 600, 700, 875])  # XYXY format
masks, scores, logits = predictor.predict(
box=input_box,
multimask_output=False
)
Combined prompts
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
box=np.array([400, 300, 700, 600]),
multimask_output=False
)
Iterative refinement
# First pass: a single click; keep the low-res logits
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True
)
# Second pass: feed the best logits back and add a background click
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),
    mask_input=logits[np.argmax(scores)][None, :, :],
    multimask_output=False
)
Automatic mask generation
Basic automatic segmentation
from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
Customized generation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,                # density of the sampled point grid
    pred_iou_thresh=0.88,              # keep masks the model itself rates highly
    stability_score_thresh=0.95,       # keep masks stable under threshold shifts
    crop_n_layers=1,                   # also run on crops to catch small objects
    crop_n_points_downscale_factor=2,  # fewer points on deeper crop layers
    min_mask_region_area=100,          # drop tiny regions/holes (needs opencv)
)
masks = mask_generator.generate(image)
Filtering masks
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
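To eyeball the generator's output, overlay every mask in a random translucent color; a minimal sketch assuming `image` is the RGB array the masks were generated from:
import matplotlib.pyplot as plt
import numpy as np
def show_masks(image, masks):
    overlay = image.astype(np.float32).copy()
    for m in masks:
        color = np.random.randint(0, 255, size=3).astype(np.float32)
        seg = m["segmentation"]  # boolean (H, W) mask
        overlay[seg] = 0.5 * overlay[seg] + 0.5 * color
    plt.imshow(overlay.astype(np.uint8))
    plt.axis("off")
    plt.show()
show_masks(image, masks)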
Batched inference
Multiple images
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
all_masks = []
for image in images:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SAM expects RGB
    predictor.set_image(image)
masks, _, _ = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks)
Multiple prompts per image
predictor.set_image(image)
points = [
np.array([[100, 100]]),
np.array([[200, 200]]),
np.array([[300, 300]])
]
all_masks = []
for point in points:
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks[np.argmax(scores)])
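For many prompts on a single image, the library also exposes `predict_torch`, which runs all box prompts in one batched forward pass; a sketch with example box coordinates (XYXY pixels):
import torch
input_boxes = torch.tensor([
    [75, 275, 525, 850],
    [425, 600, 700, 875],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, scores, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)  # masks: (num_boxes, 1, H, W) tensor on the model's device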
ONNX deployment
Export model
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam_onnx.onnx \
--return-single-mask
Use ONNX model
import onnxruntime
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
# image_embeddings, point_coords, and point_labels are prepared as shown below
masks, iou_predictions, low_res_logits = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),  # no prior mask
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32)  # original image size
    }
)
masks = masks > 0.0  # logits -> binary mask (default mask_threshold is 0.0)
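The exported decoder does not include the image encoder, so `image_embeddings` still comes from the PyTorch model, and prompt coordinates must be mapped into SAM's resized input frame. A preparation sketch reusing `predictor` from above; note the official export pads the prompt with a dummy point labeled -1 when no box is given:
predictor.set_image(image)
image_embeddings = predictor.get_image_embedding().cpu().numpy()
h, w = image.shape[:2]
# Pad with a dummy point (label -1), then map into the resized input frame
point_coords = np.concatenate([np.array([[500, 375]]), np.zeros((1, 2))], axis=0)[None, :, :]
point_labels = np.concatenate([np.array([1]), np.array([-1])], axis=0)[None, :].astype(np.float32)
point_coords = predictor.transform.apply_coords(point_coords, (h, w)).astype(np.float32)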
Common workflows
Workflow 1: Annotation tool
import cv2
predictor = SamPredictor(sam)
predictor.set_image(image)
def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True
        )
        display_mask(masks[np.argmax(scores)])  # user-supplied rendering function
cv2.namedWindow("annotate")
cv2.setMouseCallback("annotate", on_click)  # route clicks into SAM prompts
Workflow 2: Object extraction
def extract_object(image, point):
"""Extract object at point with transparent background."""
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=np.array([point]),
point_labels=np.array([1]),
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
    # Build an RGBA cutout: the mask becomes the alpha channel
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask.astype(np.uint8) * 255
return rgba
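Example call (hypothetical file names); OpenCV writes the PNG alpha channel if the array is BGRA:
cutout = extract_object(image, [500, 375])
cv2.imwrite("object.png", cv2.cvtColor(cutout, cv2.COLOR_RGBA2BGRA))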
Workflow 3: Medical image segmentation
# SAM expects 3-channel input, so replicate the grayscale scan across RGB
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
predictor.set_image(rgb_image)
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # XYXY region of interest (placeholder coords)
    multimask_output=True
)
Output format
Mask data structure
{
    "segmentation": np.ndarray,   # binary mask, shape (H, W)
    "bbox": [x, y, w, h],         # box around the mask, XYWH format
    "area": int,                  # mask area in pixels
    "predicted_iou": float,       # model's own quality estimate
    "stability_score": float,     # robustness to threshold perturbation
    "crop_box": [x, y, w, h],     # image crop used to generate this mask
    "point_coords": [[x, y]],     # point prompt that produced the mask
}
COCO RLE format
from pycocotools import mask as mask_utils
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")
decoded_mask = mask_utils.decode(rle)
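A common pattern is to serialize every generated mask as COCO RLE for downstream training; a sketch assuming `masks` came from the automatic generator above:
import json
records = []
for m in masks:
    rle = mask_utils.encode(np.asfortranarray(m["segmentation"].astype(np.uint8)))
    rle["counts"] = rle["counts"].decode("utf-8")  # make the RLE JSON-serializable
    records.append({"rle": rle, "area": m["area"], "bbox": m["bbox"]})
with open("masks.json", "w") as f:
    json.dump(records, f)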
Performance optimization
GPU memory
# Prefer the smallest variant when GPU memory is tight
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
torch.cuda.empty_cache()  # release cached allocations between images
Speed optimization
sam = sam.half()  # FP16 weights: about half the memory, faster on most GPUs
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # coarser grid than the default 32: fewer decoder passes
)
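If FP16 weights cause numerical issues, autocast is a gentler alternative that keeps the weights in FP32 and downcasts only the heavy matmuls (assumes a CUDA GPU and PyTorch >= 1.10):
with torch.autocast("cuda", dtype=torch.float16):
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )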
Common issues
| Issue | Solution |
|---|---|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |