Segment Anything Model (SAM)
The Segment Anything Model (SAM), developed by Meta AI, is a foundation model for promptable image segmentation. Given a prompt such as a point or a box, it can segment objects in images and domains it was not trained on (zero-shot generalization), with no additional training required.
Why SAM Matters for Robotics
In robotics, perception is about understanding the environment. Segmentation answers the question "which pixels belong to which object?", which is a fundamental capability for:
- Grasping: Identifying the exact boundaries of target objects
- Navigation: Distinguishing obstacles, floors, and walls
- Manipulation: Understanding object shapes for planning grasp poses
- Scene understanding: Building semantic maps of the environment
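Traditional segmentation models are trained on a fixed set of categories (e.g., COCO's 80 classes). SAM is not tied to a label set: you point at (or draw a box around) whatever you care about, and it returns a mask. Its masks are class-agnostic, so SAM tells you where an object is but not what it is.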
SAM Architecture
SAM has three components:
```text
┌─────────────┐     ┌──────────────┐     ┌──────────────┐
│    Image    │────▶│    Image     │────▶│     Mask     │
│   Encoder   │     │  Embedding   │     │   Decoder    │
│ (ViT-based) │     │   (cached)   │     │(lightweight) │
└─────────────┘     └──────────────┘     └──────────────┘
                                                ▲
                                                │
                                         ┌──────────────┐
                                         │    Prompt    │
                                         │   Encoder    │
                                         │(points/boxes/│
                                         │ text/masks)  │
                                         └──────────────┘
```
Image Encoder
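- Based on a Vision Transformer (ViT-H/16 in the largest model; ViT-L and ViT-B variants are also released)
- Processes the image once, after resizing so the longest side is 1024 px, and produces a 256×64×64 embedding
- This is the heaviest component (~630M parameters for ViT-H; the SAM paper reports ~0.15 s per image on an NVIDIA A100)
- The embedding can be cached and reused for multiple prompts on the same image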
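The encoder works on a fixed 1024-pixel input, so prompts given in original-image pixels must be rescaled before they reach the prompt encoder. `SamPredictor` does this internally. The sketch below re-implements the arithmetic in plain NumPy to make it concrete. It assumes the longest-side resize behaviour of `segment_anything.utils.transforms.ResizeLongestSide` and a 16-pixel ViT-H patch size. The function names are our own and are not part of the library.

```python
import numpy as np

TARGET_LONG_SIDE = 1024  # SAM's encoder input size
PATCH_SIZE = 16          # ViT-H/16 patch size


def preprocess_shape(old_h: int, old_w: int, long_side: int = TARGET_LONG_SIDE) -> tuple:
    """Return (new_h, new_w) after scaling the longest side to `long_side`."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def scale_points(points: np.ndarray, old_hw: tuple, long_side: int = TARGET_LONG_SIDE) -> np.ndarray:
    """Map (x, y) pixel coordinates from the original image into the resized frame."""
    old_h, old_w = old_hw
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    scaled = points.astype(np.float64).copy()
    scaled[..., 0] *= new_w / old_w
    scaled[..., 1] *= new_h / old_h
    return scaled


new_h, new_w = preprocess_shape(480, 640)
print(new_h, new_w)  # 768 1024
print(scale_points(np.array([[320, 240]]), (480, 640)))  # [[512. 384.]]
# After zero-padding to 1024x1024, the encoder sees a 64x64 grid of patches
print(TARGET_LONG_SIDE // PATCH_SIZE)  # 64
```

Forgetting this step is a common source of masks that land in the wrong place when calling the lower-level model components directly.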
Prompt Encoder
Accepts various types of prompts:
| Prompt Type | Description | Use Case |
|---|---|---|
| Points | Click on object(s) | Interactive segmentation |
| Bounding Box | Draw a rectangle around object | Object detection integration |
| Rough Mask | Coarse segmentation hint | Refinement |
| Text | Natural language description | Explored in the SAM paper as a proof of concept; not supported by the released model (see Grounded-SAM below) |
Mask Decoder
- Lightweight transformer decoder (~4M parameters)
- Fast: the SAM paper reports ~50 ms per prompt for the prompt encoder and mask decoder together, running in a web browser on CPU
- Outputs one or more masks with confidence scores
- Ambiguity-aware: when a prompt is ambiguous (e.g., clicking on a wheel vs. the whole car), it returns multiple masks
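With `multimask_output=True`, the predictor returns several candidate masks, each with a predicted-IoU score. The candidates are not guaranteed to be sorted by score. Below is a small helper that picks the highest-scoring candidate. It is a sketch: the helper name and the optional minimum-score cutoff are our own choices.

```python
import numpy as np


def select_best_mask(masks: np.ndarray, scores: np.ndarray, min_score: float = 0.0):
    """Return (mask, score) for the highest-scoring candidate, or (None, score) if below min_score.

    masks:  (C, H, W) boolean array of candidate masks
    scores: (C,) predicted-IoU scores, one per candidate
    """
    best = int(np.argmax(scores))
    if scores[best] < min_score:
        return None, float(scores[best])
    return masks[best], float(scores[best])


# Three toy candidates: a part, the whole object, and a larger region
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, 1:2, 1:2] = True
masks[1, 1:3, 1:3] = True
masks[2, 0:4, 0:4] = True
scores = np.array([0.81, 0.97, 0.64])

mask, score = select_best_mask(masks, scores)
print(score, mask.sum())  # 0.97 4
```

Returning `None` below a threshold lets a robot refuse to act on a low-confidence mask instead of grasping at a poor one.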
SAM v1 (2023)
Key Features
- Zero-shot transfer: Works on objects and domains it was never trained on
- Dataset released: SA-1B, with 11M images and 1.1B automatically generated masks (the largest segmentation dataset at the time of release)
- A video dataset, SA-V (about 51K videos), followed later together with SAM 2
Quick Start with SAM v1
```python
# Install: pip install segment-anything
# Checkpoint: download sam_vit_h_4b8939.pth from the segment-anything repository
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# Load image (OpenCV reads BGR; SAM expects RGB)
image = cv2.imread("robot_workspace.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)  # runs the heavy image encoder once

# Prompt with a point (x=500, y=375); label 1 = foreground, 0 = background
input_point = np.array([[500, 375]])
input_label = np.array([1])

# Predict
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # returns 3 masks at different granularities
)

# Masks are not sorted by quality: pick the one with the highest predicted IoU
best_mask = masks[np.argmax(scores)]
print(f"Mask shape: {best_mask.shape}")   # (H, W) boolean
print(f"Predicted IoU scores: {scores}")  # one score per candidate mask
```
Using Bounding Box Prompts
```python
# Prompt with a bounding box [x1, y1, x2, y2] in pixel coordinates
input_box = np.array([100, 100, 400, 400])
masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False,  # a box is rarely ambiguous, so one mask suffices
)
```
Combining Points and Boxes
```python
# Refine a box prompt with additional point prompts
input_box = np.array([100, 100, 400, 400])
input_point = np.array([[250, 250]])  # point inside the box
input_label = np.array([1])           # foreground
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
```
SAM 2 (2024)
SAM 2 extends the model from images to videos, enabling real-time object segmentation and tracking across frames.
Key Improvements
| Feature | SAM v1 | SAM 2 |
|---|---|---|
| Input | Images only | Images + Videos |
| Temporal tracking | ❌ | ✅ Memory-based |
| Architecture | ViT encoder | Hiera (hierarchical) |
| Speed | ViT-H encoder ~0.15 s per image (A100) | ~44 FPS on video (A100, per Meta) |
| Interactive video | ❌ | ✅ Point/box on any frame |
| Occlusion handling | ❌ | ✅ Memory mechanism |
Architecture of SAM 2
```text
SAM 2 pipeline (per frame t)

Frame t ──▶ Hiera Encoder ──▶ Frame embedding
                                     │
                                     ▼
Memory Bank ────────────────▶ Memory Attention
(memories of past frames             │
 and prompted frames)                ▼
     ▲      Prompt Encoder ──▶ Mask Decoder ──▶ Mask_t
     │      (points/boxes/masks)                  │
     │                                            │
     └──────────── Memory Encoder ◀───────────────┘
```
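After each prediction, the memory encoder turns the frame and its mask into a memory. The memory bank keeps memories of recent frames plus any frames the user prompted. Memory attention conditions the current frame's embedding on these memories, which lets the model keep tracking an object through occlusions, deformations, and appearance changes.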
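The SAM 2 paper describes the memory bank as two FIFO queues: one holds memories of up to N recent frames and the other holds up to M prompted frames. The toy class below mimics only that bookkeeping, meaning which frames' memories are available to memory attention. It stores placeholder strings instead of real feature maps, and the class name and queue sizes are illustrative assumptions, not SAM 2's code.

```python
from collections import deque


class ToyMemoryBank:
    """Bookkeeping-only sketch of SAM 2's memory bank (no real features)."""

    def __init__(self, max_recent: int = 6, max_prompted: int = 2):
        self.recent = deque(maxlen=max_recent)      # FIFO of recent frame memories
        self.prompted = deque(maxlen=max_prompted)  # FIFO of prompted frame memories

    def add(self, frame_idx: int, prompted: bool = False) -> None:
        memory = f"mem[{frame_idx}]"  # placeholder for an encoded frame + mask
        if prompted:
            self.prompted.append((frame_idx, memory))
        else:
            self.recent.append((frame_idx, memory))

    def frames_attended(self) -> list:
        """Frame indices whose memories would condition the next prediction."""
        return sorted(idx for idx, _ in list(self.prompted) + list(self.recent))


bank = ToyMemoryBank(max_recent=3)
bank.add(0, prompted=True)  # user clicked on frame 0
for t in range(1, 10):
    bank.add(t)
print(bank.frames_attended())  # [0, 7, 8, 9]
```

The prompted frame sits in its own queue, so in this model the user's click keeps influencing predictions long after it has left the window of recent frames.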
Quick Start with SAM 2
```python
# Install from source: https://github.com/facebookresearch/sam2
# (config/checkpoint paths below follow that repository's layout; adjust to your checkout)
import os

import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

# Load model (config file, checkpoint)
predictor = build_sam2_video_predictor(
    "configs/sam2/sam2_hiera_l.yaml",
    "checkpoints/sam2_hiera_large.pt",
)

# The video predictor expects a directory of JPEG frames (00000.jpg, 00001.jpg, ...)
video_dir = "path/to/video/frames/"
frame_names = sorted(
    f for f in os.listdir(video_dir) if f.lower().endswith((".jpg", ".jpeg"))
)
print(f"{len(frame_names)} frames found")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Initialize state (loads the frames)
    inference_state = predictor.init_state(video_path=video_dir)

    # Add a prompt on frame 0
    ann_frame_idx = 0
    ann_obj_id = 1  # any unique integer per tracked object
    # Point: (x, y), label: 1=foreground, 0=background
    points = np.array([[210, 350]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

    # Propagate through the video: SAM 2 tracks the object automatically
    video_segments = {}
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[frame_idx] = {
            out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()  # (1, H, W) boolean
            for i, out_obj_id in enumerate(obj_ids)
        }
```
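For robot control, the per-frame masks in `video_segments` are usually reduced to something smaller, such as the object's pixel centroid over time. The sketch below assumes the structure built above, where `video_segments[frame_idx][obj_id]` is a boolean array of shape `(1, H, W)`. Frames where the mask is empty (for example, when the object is fully occluded) are reported as `None`. The function name is our own.

```python
import numpy as np


def centroid_track(video_segments: dict, obj_id: int) -> dict:
    """Map frame index -> (x, y) pixel centroid of obj_id's mask, or None if empty."""
    track = {}
    for frame_idx in sorted(video_segments):
        mask = video_segments[frame_idx].get(obj_id)
        if mask is None:
            track[frame_idx] = None
            continue
        ys, xs = np.nonzero(np.squeeze(mask, axis=0))
        track[frame_idx] = (float(xs.mean()), float(ys.mean())) if xs.size else None
    return track


# Synthetic example: a 2x2 object moving one pixel right per frame, then vanishing
segments = {}
for t in range(3):
    m = np.zeros((1, 6, 6), dtype=bool)
    m[0, 2:4, t:t + 2] = True
    segments[t] = {1: m}
segments[3] = {1: np.zeros((1, 6, 6), dtype=bool)}
print(centroid_track(segments, obj_id=1))
# {0: (0.5, 2.5), 1: (1.5, 2.5), 2: (2.5, 2.5), 3: None}
```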
Interactive Video Segmentation
SAM 2 supports interactive refinement. If tracking drifts on frame N, you can add correction prompts on that frame and propagate again, and the corrected object is tracked from there on.
```python
# After the initial propagation, correct the mask on frame 50
correction_frame = 50
correction_points = np.array([[220, 360]], dtype=np.float32)
correction_labels = np.array([1], dtype=np.int32)  # use 0 to exclude a wrongly included region

# Add the correction as an extra prompt on the same object (this refines; it does not reset)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=correction_frame,
    obj_id=ann_obj_id,
    points=correction_points,
    labels=correction_labels,
)

# Re-propagate from frame 50 onwards
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(
    inference_state, start_frame_idx=correction_frame
):
    video_segments[frame_idx] = {
        out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(obj_ids)
    }
```
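Choosing where to add correction prompts can be partly automated. One simple heuristic, which is our assumption and not part of SAM 2, is to flag frames where the tracked mask's area changes abruptly relative to the previous frame and review those frames first. The sketch below works on the same `video_segments` structure. The 1.5× ratio is an arbitrary starting value.

```python
import numpy as np


def flag_suspect_frames(video_segments: dict, obj_id: int, max_ratio: float = 1.5) -> list:
    """Return frame indices where mask area jumps by more than max_ratio (or drops to zero)."""
    suspects = []
    prev_area = None
    for frame_idx in sorted(video_segments):
        area = int(video_segments[frame_idx][obj_id].sum())
        if prev_area is not None:
            if area == 0 or prev_area == 0:
                if area != prev_area:
                    suspects.append(frame_idx)
            elif max(area, prev_area) / min(area, prev_area) > max_ratio:
                suspects.append(frame_idx)
        prev_area = area
    return suspects


# Areas: 100, 104, 300 (sudden growth, e.g. leaking onto the table), 296, 0 (lost)
areas = [100, 104, 300, 296, 0]
segments = {}
for t, a in enumerate(areas):
    m = np.zeros((1, 20, 20), dtype=bool)
    m.reshape(-1)[:a] = True  # set the first `a` pixels
    segments[t] = {1: m}
print(flag_suspect_frames(segments, obj_id=1))  # [2, 4]
```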
Robotics Applications
1. Object Grasping Pipeline
Combine SAM with a depth camera for robotic grasping:
```python
import numpy as np
import pyrealsense2 as rs
from segment_anything import SamPredictor, sam_model_registry

# 1. Capture RGB + depth from a RealSense camera
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
profile = pipeline.start(config)
depth_scale = profile.get_device().first_depth_sensor().get_depth_scale()  # meters per unit

# Align depth to the color image so mask pixels index the matching depth values
align = rs.align(rs.stream.color)
frames = align.process(pipeline.wait_for_frames())
color_frame = frames.get_color_frame()
depth_frame = frames.get_depth_frame()
color_image = np.asanyarray(color_frame.get_data())
depth_image = np.asanyarray(depth_frame.get_data())
pipeline.stop()

# 2. Segment the target object with SAM (the stream is BGR)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(color_image, image_format="BGR")

# Point prompt on the target object (here: the image center)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=False,
)
mask = masks[0]  # boolean mask (H, W)

# 3. Extract the object's 3D point cloud (camera frame, meters)
intrinsics = color_frame.profile.as_video_stream_profile().intrinsics
object_points = []
for v, u in zip(*np.nonzero(mask)):  # visit only pixels inside the mask
    depth = depth_image[v, u] * depth_scale
    if 0.1 < depth < 2.0:  # valid range
        point = rs.rs2_deproject_pixel_to_point(intrinsics, [float(u), float(v)], float(depth))
        object_points.append(point)

object_points = np.array(object_points)
if len(object_points) == 0:
    raise RuntimeError("No valid depth inside the mask")
centroid = object_points.mean(axis=0)
print(f"Object center (m): {centroid}")  # use for grasp planning
```
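The loop above calls `rs2_deproject_pixel_to_point` once per masked pixel, which gets slow for large objects. If lens distortion can be ignored (an assumption; check your camera's distortion model), deprojection reduces to the pinhole equations X = (u - cx)·Z / fx and Y = (v - cy)·Z / fy, which vectorize in NumPy. The function below is a sketch. Its `fx, fy, cx, cy` parameters correspond to the `fx`, `fy`, `ppx`, `ppy` fields of a RealSense intrinsics object, and `depth_m` is depth already converted to meters.

```python
import numpy as np


def deproject_mask(mask, depth_m, fx, fy, cx, cy, z_min=0.1, z_max=2.0):
    """Return an (N, 3) array of camera-frame points (meters) for valid masked pixels.

    mask:    (H, W) boolean object mask
    depth_m: (H, W) depth in meters, aligned to the same image as the mask
    """
    v, u = np.nonzero(mask)           # row (v) and column (u) indices
    z = depth_m[v, u]
    keep = (z > z_min) & (z < z_max)  # drop missing / out-of-range depth
    u, v, z = u[keep], v[keep], z[keep]
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    return np.stack([x, y, z], axis=1)


# Synthetic check: a flat 21x21 patch 0.5 m away, centered on the principal point
mask = np.zeros((480, 640), dtype=bool)
mask[230:251, 310:331] = True
depth = np.full((480, 640), 0.5)
points = deproject_mask(mask, depth, fx=600.0, fy=600.0, cx=320.0, cy=240.0)
print(points.shape)                                         # (441, 3)
print(np.allclose(points.mean(axis=0), [0.0, 0.0, 0.5]))    # True
```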
2. Scene Understanding for Navigation
```python
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Segment everything in the scene (masks are class-agnostic: no labels)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,           # density of the point-prompt grid
    pred_iou_thresh=0.88,         # mask quality threshold
    stability_score_thresh=0.95,  # mask stability threshold
    min_mask_region_area=100,     # remove small holes/islands (requires OpenCV)
)

image_rgb = cv2.cvtColor(cv2.imread("robot_workspace.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image_rgb)
# Each mask dict has: segmentation, area, bbox (XYWH), predicted_iou, stability_score, ...

# Sort by area (largest first)
masks_sorted = sorted(masks, key=lambda x: x["area"], reverse=True)
for i, mask_data in enumerate(masks_sorted[:5]):
    print(f"Object {i}: area={mask_data['area']}px, "
          f"bbox={mask_data['bbox']}, "
          f"iou={mask_data['predicted_iou']:.3f}")
```
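The automatic generator returns class-agnostic masks that may overlap. A navigation stack therefore typically flattens them into a single instance-ID image before attaching semantics or projecting them into a map. The sketch below shows one way to do that. It paints masks from largest to smallest so that smaller objects resting on larger surfaces (a box on the floor, say) stay visible. The painting order is our design choice. The dictionary keys follow the `segmentation` and `area` fields listed above.

```python
import numpy as np


def masks_to_instance_image(masks: list, shape: tuple) -> np.ndarray:
    """Flatten SamAutomaticMaskGenerator-style masks into an (H, W) int32 ID image.

    0 means "no mask"; IDs 1..N follow descending area, so small masks overwrite large ones.
    """
    instance_ids = np.zeros(shape, dtype=np.int32)
    ordered = sorted(masks, key=lambda m: m["area"], reverse=True)
    for i, m in enumerate(ordered, start=1):
        instance_ids[m["segmentation"]] = i
    return instance_ids


floor = np.zeros((4, 6), dtype=bool)
floor[2:, :] = True  # bottom half of the image
box = np.zeros((4, 6), dtype=bool)
box[1:3, 2:4] = True  # an object sitting on the floor
fake_masks = [
    {"segmentation": box, "area": int(box.sum())},
    {"segmentation": floor, "area": int(floor.sum())},
]
print(masks_to_instance_image(fake_masks, (4, 6)))
# [[0 0 0 0 0 0]
#  [0 0 2 2 0 0]
#  [1 1 2 2 1 1]
#  [1 1 1 1 1 1]]
```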
3. Combining SAM with CLIP for Language-Guided Segmentation
```python
import clip  # OpenAI CLIP: pip install git+https://github.com/openai/CLIP.git
import numpy as np
import torch
from PIL import Image


# Use CLIP to find the target object, then SAM to segment it.
# Load CLIP with e.g.: clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
def segment_by_text(image, text_query, sam_predictor, clip_model, clip_preprocess,
                    stride=32, device="cpu"):
    """image: (H, W, 3) uint8 RGB array. Returns the SAM mask at the best-matching point."""
    # 1. Generate a grid of candidate point prompts
    h, w = image.shape[:2]
    grid_points = [[x, y] for y in range(0, h, stride) for x in range(0, w, stride)]

    # 2. Crop a (up to) 64x64 patch around each point and preprocess it for CLIP
    patches = []
    for x, y in grid_points:
        x1, y1 = max(0, x - 32), max(0, y - 32)
        x2, y2 = min(w, x + 32), min(h, y + 32)
        patches.append(clip_preprocess(Image.fromarray(image[y1:y2, x1:x2])))

    # 3. Cosine similarity between every patch and the text query (one batched pass)
    with torch.no_grad():
        image_features = clip_model.encode_image(torch.stack(patches).to(device))
        text_features = clip_model.encode_text(clip.tokenize([text_query]).to(device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarities = (image_features @ text_features.T).squeeze(1).float().cpu().numpy()

    # 4. Use the most similar point as the SAM prompt
    best_point = np.array([grid_points[int(np.argmax(similarities))]])
    sam_predictor.set_image(image)
    masks, scores, _ = sam_predictor.predict(
        point_coords=best_point,
        point_labels=np.array([1]),
        multimask_output=False,
    )
    return masks[0]
```
Advanced: Fine-tuning SAM for Your Domain
While SAM works zero-shot, fine-tuning can improve performance on domain-specific objects (e.g., specific robot parts, industrial components).
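```python
import numpy as np
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pretrained SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
sam.train()

# Freeze image encoder, only train prompt encoder + mask decoder
for param in sam.image_encoder.parameters():
    param.requires_grad = False

resize = ResizeLongestSide(sam.image_encoder.img_size)  # longest side -> 1024


# Your custom dataset
class CustomSegDataset(torch.utils.data.Dataset):
    def __init__(self, images, masks, points):
        self.images = images  # list of (H, W, 3) uint8 RGB arrays
        self.masks = masks    # list of (H, W) boolean arrays
        self.points = points  # list of (N, 2) foreground point prompts, (x, y) pixels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx], self.points[idx]


# train_images, train_masks, train_points: your own data
dataset = CustomSegDataset(train_images, train_masks, train_points)

# Fine-tune with a small dataset (100-500 images may be enough)
optimizer = torch.optim.Adam(
    list(sam.prompt_encoder.parameters()) + list(sam.mask_decoder.parameters()),
    lr=1e-5,
)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(10):
    # One image per step: SAM's mask decoder expects a single image embedding per call
    for idx in torch.randperm(len(dataset)).tolist():
        image, gt_mask, points = dataset[idx]
        original_size = image.shape[:2]

        # Resize (longest side 1024), then normalize and pad with sam.preprocess
        input_image = resize.apply_image(image)
        input_size = input_image.shape[:2]
        x = torch.as_tensor(input_image, device=device).permute(2, 0, 1).contiguous()[None]
        with torch.no_grad():  # frozen encoder: no activations kept for backprop
            image_embedding = sam.image_encoder(sam.preprocess(x))

        # Prompts must be mapped into the resized frame too
        coords = resize.apply_coords(points.astype(np.float32), original_size)
        coords = torch.as_tensor(coords, dtype=torch.float, device=device)[None]
        labels = torch.ones(coords.shape[:2], dtype=torch.int, device=device)

        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=(coords, labels),
            boxes=None,
            masks=None,
        )
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Upsample the 256x256 logits back to the original image size
        pred = sam.postprocess_masks(low_res_masks, input_size, original_size)

        target = torch.as_tensor(gt_mask, dtype=torch.float, device=device)[None, None]
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```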
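To tell whether fine-tuning actually helps, measure mask quality on held-out images before and after training. A common choice is mean IoU between predicted and ground-truth masks. The sketch below computes it with NumPy. Treating an empty-vs-empty pair as a perfect match is our own convention.

```python
import numpy as np


def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Intersection-over-union of two boolean masks of the same shape."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty: count as a perfect match
    return float(np.logical_and(pred, gt).sum() / union)


def mean_iou(preds: list, gts: list) -> float:
    return float(np.mean([mask_iou(p, g) for p, g in zip(preds, gts)]))


gt = np.zeros((10, 10), dtype=bool)
gt[2:8, 2:8] = True                              # 36 pixels
pred = np.zeros_like(gt)
pred[2:8, 3:9] = True                            # shifted one column: 30 overlap, 42 union
print(round(mask_iou(pred, gt), 4))              # 0.7143
print(round(mean_iou([pred, gt], [gt, gt]), 4))  # 0.8571
```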
Comparison with Other Segmentation Models
| Model | Type | Zero-shot | Interactive | Video | Speed (indicative, hardware-dependent) |
|---|---|---|---|---|---|
| SAM | Universal | ✅ | ✅ | ❌ | ~50 ms per prompt (decoder only; ViT-H encoder ~0.15 s per image on A100) |
| SAM 2 | Universal | ✅ | ✅ | ✅ | ~44 FPS on video (Meta's figure) |
| YOLO-SAM | Detection+Seg | ✅ | ❌ | ❌ | 80ms |
| Mask R-CNN | Instance | ❌ | ❌ | ❌ | 100ms |
| DeepLab | Semantic | ❌ | ❌ | ❌ | 60ms |
| Grounded-SAM | Language-guided | ✅ | ✅ | ❌ | 120ms |
Ecosystem: SAM Variants
The SAM ecosystem has expanded rapidly:
- Grounded-SAM: Combines Grounding DINO (text detection) + SAM for language-guided segmentation
- Fast-SAM: YOLOv8-based, 50x faster than SAM but slightly less accurate
- Mobile-SAM: Lightweight version for edge devices (~10x faster)
- SAM-Med2D: Fine-tuned for medical imaging
- SAM 2.1: Minor improvements to SAM 2 with better tracking consistency
Summary
| Topic | Key Takeaway |
|---|---|
| SAM v1 | Universal image segmentation with point/box prompts |
| SAM 2 | Extends to video with memory-based tracking |
| Robotics use | Grasping, navigation, scene understanding |
| Fine-tuning | Freeze encoder, train decoder on domain data |
| Ecosystem | Grounded-SAM, Fast-SAM, Mobile-SAM variants |
SAM 2.1 (Late 2024)
SAM 2.1 is an incremental upgrade over SAM 2, released by Meta AI in late 2024 as a new set of checkpoints. It focuses on robustness in challenging video scenarios, which is exactly the kind of improvement that matters for real-time robotics.
Key Improvements over SAM 2
| Aspect | SAM 2 | SAM 2.1 |
|---|---|---|
| Architecture | Hiera image encoder + streaming memory | Same design; small changes to the positional encoding of spatial and object-pointer memory |
| Visually similar and small objects | A known weak spot | Trained with additional data augmentation that simulates these cases |
| Occlusion handling | Trained on shorter frame sequences | Trained on longer frame sequences |
| Checkpoints | sam2_hiera_{tiny, small, base_plus, large} | sam2.1_hiera_{tiny, small, base_plus, large} |
Architecture Changes
SAM 2.1 does not introduce a new architecture. It keeps SAM 2's components (Hiera image encoder, memory attention, memory encoder and memory bank, prompt encoder, mask decoder) and the same checkpoint sizes, so the SAM 2 pipeline diagram above still applies. According to Meta's release notes, the gains come mainly from training changes: extra data augmentation for visually similar and small objects, training on longer frame sequences to improve occlusion handling, and tweaks to the positional encoding of spatial and object-pointer memory.
Performance on SA-V Dataset
The standard video object segmentation metric is J&F, the average of region similarity J (mask IoU) and boundary accuracy F (contour F-measure). Meta lists J&F on SA-V test, MOSE val, and LVOS v2 for every SAM 2 and SAM 2.1 checkpoint in the sam2 repository README; consult that table for exact figures. The 2.1 checkpoints score higher on these benchmarks at essentially the same parameter count and speed. That makes them a drop-in upgrade for robotics applications where tracking reliability matters.
Using SAM 2.1
```python
import cv2
import numpy as np
import torch
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build SAM 2.1 model (config file, checkpoint; paths follow the sam2 repository layout)
sam2_model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "checkpoints/sam2.1_hiera_large.pt",
    device=device,
)
predictor = SAM2ImagePredictor(sam2_model)

# Load image
image = cv2.imread("scene.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)

# Segment with a point prompt
masks, scores, logits = predictor.predict(
    point_coords=np.array([[450, 300]]),  # click point (x, y)
    point_labels=np.array([1]),           # 1 = foreground
    multimask_output=True,
)

# Select the best mask
best_mask = masks[scores.argmax()]
print(f"Confidence: {scores.max():.3f}")
print(f"Mask shape: {best_mask.shape}")

# Video segmentation with SAM 2.1
video_predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "checkpoints/sam2.1_hiera_large.pt",
    device=device,
)

# The predictor reads a directory of JPEG frames, e.g. extracted from robot_view.mp4 with:
#   ffmpeg -i robot_view.mp4 -q:v 2 -start_number 0 robot_view_frames/%05d.jpg
state = video_predictor.init_state(video_path="robot_view_frames")

# Add a prompt on the first frame
ann_frame_idx = 0
ann_obj_id = 1
points = np.array([[500, 350]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Track through the entire video
for frame_idx, obj_ids, mask_logits in video_predictor.propagate_in_video(state):
    # mask_logits: (num_objects, 1, H, W); threshold at 0 to get boolean masks
    frame_masks = (mask_logits > 0).cpu().numpy()
    # Use frame_masks for robot control...
```
Grounded-SAM
Grounded-SAM combines Grounding DINO (an open-vocabulary object detector) with SAM to enable text-prompted segmentation. Instead of clicking on an object, you describe it in natural language and Grounded-SAM segments it.
Architecture
```text
┌────────────┐     ┌────────────────┐     ┌────────────┐
│    Text    │────▶│   Grounding    │────▶│    SAM     │
│   Input    │     │      DINO      │     │ Encoder +  │
│ "red cup"  │     │   (detector)   │     │  Decoder   │
└────────────┘     └────────────────┘     └────────────┘
                            │                    │
                            ▼                    ▼
                   ┌────────────────┐     ┌────────────┐
                   │ Bounding boxes │     │ Seg masks  │
                   │ (SAM prompts)  │     │ per object │
                   └────────────────┘     └────────────┘
```
The pipeline works in two stages:
1. Grounding DINO takes the image and the text prompt, detects all objects matching the description, and outputs bounding boxes.
2. SAM takes those bounding boxes as prompts and generates precise pixel-level masks.
Code Example: Text-Prompted Segmentation
```python
import cv2
import numpy as np
import torch
from groundingdino.util.inference import load_image, load_model, predict
from segment_anything import SamPredictor, sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Grounding DINO
grounding_model = load_model(
    "groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "weights/groundingdino_swint_ogc.pth",
)

# Load SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
sam_predictor = SamPredictor(sam)

# Load image: load_image returns the RGB array and the tensor Grounding DINO expects
image_rgb, image_tensor = load_image("kitchen_scene.jpg")
image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)  # BGR copy for drawing with OpenCV

# Step 1: Detect objects with Grounding DINO
TEXT_PROMPT = "the red cup on the table"
boxes, logits, phrases = predict(
    model=grounding_model,
    image=image_tensor,
    caption=TEXT_PROMPT,
    box_threshold=0.3,
    text_threshold=0.25,
    device=device,
)
print(f"Detected {len(boxes)} objects: {phrases}")

# Step 2: Segment with SAM using the detected boxes
sam_predictor.set_image(image_rgb)

# Convert boxes from normalized [cx, cy, w, h] to pixel [x1, y1, x2, y2]
height, width = image_rgb.shape[:2]
boxes_xyxy = boxes.clone()
boxes_xyxy[:, 0] = (boxes[:, 0] - boxes[:, 2] / 2) * width
boxes_xyxy[:, 1] = (boxes[:, 1] - boxes[:, 3] / 2) * height
boxes_xyxy[:, 2] = (boxes[:, 0] + boxes[:, 2] / 2) * width
boxes_xyxy[:, 3] = (boxes[:, 1] + boxes[:, 3] / 2) * height

# predict() takes a single box; predict_torch handles a batch of boxes
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_rgb.shape[:2])
masks, scores, _ = sam_predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes.to(device),
    multimask_output=False,  # one mask per box
)
masks = masks[:, 0].cpu().numpy()    # (N, H, W) boolean
scores = scores[:, 0].cpu().numpy()  # (N,) predicted IoU

# Visualize results
for i, (mask, score, phrase) in enumerate(zip(masks, scores, phrases)):
    color = np.random.randint(0, 255, 3)
    overlay = image.copy()
    overlay[mask] = (overlay[mask] * 0.5 + color * 0.5).astype(np.uint8)
    cv2.putText(overlay, f"{phrase}: {score:.2f}", (10, 30 + i * 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, tuple(map(int, color)), 2)
    cv2.imwrite(f"mask_{i}.png", overlay)
    print(f"Segmented: {phrase} (confidence: {score:.3f})")
```
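A phrase such as "cup" can yield several overlapping boxes for the same object, and each one would become a separate SAM prompt. A standard remedy is greedy non-maximum suppression (NMS) on the detector's boxes before prompting SAM. The NumPy sketch below is a generic NMS, not something Grounded-SAM necessarily applies for you, and the 0.5 IoU threshold is an assumed default. `torchvision.ops.nms` provides the same operation for tensors.

```python
import numpy as np


def nms_xyxy(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float = 0.5) -> list:
    """Greedy NMS. boxes: (N, 4) [x1, y1, x2, y2]; returns kept indices, best first."""
    order = np.argsort(scores)[::-1]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Overlap of the current best box with every remaining box
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= iou_thresh]  # drop boxes that overlap the kept one too much
    return keep


boxes = np.array([[100, 100, 200, 200],   # cup
                  [105, 102, 205, 198],   # near-duplicate of the cup
                  [300, 120, 380, 220]],  # a second cup
                 dtype=float)
scores = np.array([0.62, 0.55, 0.48])
print(nms_xyxy(boxes, scores))  # [0, 2]
```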
Robotics Application: Language-Guided Manipulation
In a robotics setting, Grounded-SAM enables language-guided grasping. The robot receives a command like "pick up the red cup on the table", segments the target object, and uses the mask for grasp planning.
```text
Language command: "pick up the red cup on the table"
         │
         ▼
┌─────────────────┐
│  Grounded-SAM   │──▶ Segmentation mask of the cup
│  (text → mask)  │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│  Grasp Planner  │──▶ Grasp pose (position + orientation)
│  (mask → pose)  │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│    Robot Arm    │──▶ Execute grasp
│   Controller    │
└─────────────────┘
```
```python
# Simplified language-guided grasping pipeline.
# grounded_sam_segment, grasp_from_mask and select_best_grasp are hypothetical helpers,
# and robot_arm stands for whatever arm-control interface your robot exposes.
def language_guided_grasp(image, depth_image, text_command, robot_arm):
    """
    Segment an object by language description and grasp it.

    Args:
        image: RGB camera image from the robot's viewpoint
        depth_image: Depth image aligned to `image`
        text_command: Natural language description of the target
        robot_arm: Robot arm controller instance
    """
    # 1. Segment the target object
    mask = grounded_sam_segment(image, text_command)
    if mask is None or not mask.any():
        raise RuntimeError(f"No object found for: {text_command!r}")

    # 2. Compute grasp candidates from the mask:
    #    - object centroid for the approach point
    #    - principal axis for gripper orientation
    grasp_candidates = grasp_from_mask(mask, image, depth_image)

    # 3. Select the best grasp (highest confidence, reachable)
    best_grasp = select_best_grasp(grasp_candidates, robot_arm.workspace)

    # 4. Execute: approach, grasp, lift
    robot_arm.move_to(best_grasp.pre_grasp_pose)
    robot_arm.open_gripper()
    robot_arm.move_to(best_grasp.grasp_pose)
    robot_arm.close_gripper()
    robot_arm.move_to(best_grasp.lift_pose)
    return best_grasp
```
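The hypothetical `grasp_from_mask` step can start as simply as its comments describe. The mask centroid gives the approach point, and the principal axis of the mask pixels gives the object's long direction, so a parallel gripper closes across it. Below is a 2D NumPy sketch. It assumes a top-down camera and ignores depth, gripper width, and collision checks. The function name and the returned closing angle are our own conventions.

```python
import numpy as np


def grasp_from_mask_2d(mask: np.ndarray):
    """Return ((u, v) centroid, finger-closing angle in radians) for a top-down grasp.

    The closing direction is the principal axis rotated by 90 degrees, so the
    fingers close across the object's narrow side.
    """
    v, u = np.nonzero(mask)
    if u.size < 2:
        raise ValueError("mask too small to estimate an axis")
    centroid = np.array([u.mean(), v.mean()])
    coords = np.stack([u, v], axis=1) - centroid
    # Principal axis = eigenvector of the 2x2 covariance with the largest eigenvalue
    eigvals, eigvecs = np.linalg.eigh(np.cov(coords, rowvar=False))
    major = eigvecs[:, np.argmax(eigvals)]
    angle = np.arctan2(major[1], major[0]) + np.pi / 2
    angle = (angle + np.pi / 2) % np.pi - np.pi / 2  # wrap to [-pi/2, pi/2)
    return (float(centroid[0]), float(centroid[1])), float(angle)


# A horizontal bar: long along image x, so the fingers should close along y
mask = np.zeros((50, 100), dtype=bool)
mask[20:30, 10:90] = True
center, angle = grasp_from_mask_2d(mask)
print(center, round(abs(np.degrees(angle)), 1))  # (49.5, 24.5) 90.0
```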
Fast-SAM
Fast-SAM (from CASIA) is a real-time segmentation model that its authors report as about 50x faster than SAM while maintaining reasonable accuracy. It is built on the YOLOv8 segmentation architecture (YOLOv8-seg) and is well suited to edge deployment on resource-constrained robots.
Architecture
```text
Stage 1: all-instance segmentation          Stage 2: prompt-guided selection
┌──────────────┐    ┌─────────────────┐     ┌──────────────────────┐
│   YOLOv8     │───▶│ Detection head  │────▶│ Pick mask(s) by      │
│  backbone +  │    │ + mask branch   │     │ point / box / text   │
│    neck      │    │ (prototypes)    │     │ (text uses CLIP)     │
└──────────────┘    └─────────────────┘     └──────────────────────┘

Total params: ~68M for FastSAM-x (vs ~640M for SAM ViT-H)
```
Fast-SAM replaces SAM's heavy ViT encoder and promptable decoder with a YOLOv8-seg network that produces boxes and masks for every instance in a single forward pass. Prompts are applied afterwards, as a selection step over those masks, rather than being fed into a decoder.
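The Fast-SAM paper splits the task into all-instance segmentation followed by prompt-guided selection, so a box prompt amounts to picking the candidate mask that best matches the box. The sketch below implements that selection as "highest IoU between the mask and the filled box". It is a simplified reading of Fast-SAM's selection stage, not its actual code, and the function name is our own.

```python
import numpy as np


def select_mask_by_box(masks: np.ndarray, box) -> int:
    """Index of the (N, H, W) candidate mask with the highest IoU against box [x1, y1, x2, y2]."""
    x1, y1, x2, y2 = (int(round(c)) for c in box)
    box_mask = np.zeros(masks.shape[1:], dtype=bool)
    box_mask[y1:y2, x1:x2] = True
    inter = np.logical_and(masks, box_mask).sum(axis=(1, 2))
    union = np.logical_or(masks, box_mask).sum(axis=(1, 2))
    iou = inter / np.maximum(union, 1)
    return int(np.argmax(iou))


candidates = np.zeros((3, 40, 40), dtype=bool)
candidates[0, 0:40, 0:40] = True    # table surface
candidates[1, 10:20, 10:20] = True  # mug
candidates[2, 25:35, 25:35] = True  # plate
print(select_mask_by_box(candidates, [9, 9, 21, 21]))  # 1
```

Because selection is a cheap post-processing step, one Fast-SAM forward pass can serve many different prompts on the same frame.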
When to Use Fast-SAM
| Scenario | Recommended? | Reason |
|---|---|---|
| Real-time robot segmentation (>30fps) | ✅ Yes | YOLO backbone is extremely fast |
| Edge deployment (Jetson, Raspberry Pi) | ✅ Yes | 68M params vs 640M |
| High-precision medical/scientific | ❌ No | Slightly lower accuracy |
| Video tracking with occlusion | ❌ No | No memory mechanism |
| Batch processing of static images | ⚠️ Maybe | SAM is more accurate |
Tradeoffs
| Metric | SAM | Fast-SAM | Ratio |
|---|---|---|---|
| Speed (images/sec, A100) | 1.5 | 75 | 50x faster |
| Speed (images/sec, CPU) | 0.02 | 1.2 | 60x faster |
| mIoU (COCO) | 0.83 | 0.75 | -10% |
| Parameters | 640M | 68M | 10x smaller |
| Model size | 2.5 GB | 170 MB | 15x smaller |
The ~10% accuracy drop is often acceptable for robotics tasks where speed is critical, such as closed-loop control at camera frame rate.
Code Example: Real-Time Segmentation
```python
import cv2
import numpy as np
from ultralytics import FastSAM

# Load Fast-SAM model (ultralytics downloads the weights on first use)
model = FastSAM("FastSAM-s.pt")  # s = small, x = large


def realtime_segmentation(camera_id=0, min_conf=0.5):
    """Real-time segmentation from a webcam or robot camera. Press 'q' to quit."""
    cap = cv2.VideoCapture(camera_id)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Segment everything (no prompt needed)
        results = model(frame, retina_masks=True, verbose=False)
        overlay = frame.copy()
        if results[0].masks is not None:  # None when nothing was detected
            masks = results[0].masks.data.cpu().numpy()
            boxes = results[0].boxes.xyxy.cpu().numpy()
            scores = results[0].boxes.conf.cpu().numpy()
            for i, (mask, box, score) in enumerate(zip(masks, boxes, scores)):
                if score < min_conf:
                    continue
                # Color mask overlay
                color = np.random.randint(0, 255, 3, dtype=np.uint8)
                mask_bool = mask.astype(bool)
                overlay[mask_bool] = (overlay[mask_bool] * 0.5 + color * 0.5).astype(np.uint8)
                # Bounding box and label
                x1, y1, x2, y2 = box.astype(int)
                bgr = tuple(map(int, color))
                cv2.rectangle(overlay, (x1, y1), (x2, y2), bgr, 2)
                cv2.putText(overlay, f"Obj {i}: {score:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, bgr, 2)
        cv2.imshow("Fast-SAM Real-time", overlay)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    realtime_segmentation()
```
```python
# Point-prompted segmentation with Fast-SAM (reuses `model` from above).
# The points=/labels= keywords follow recent ultralytics releases; older releases
# handled prompts through a separate prompt-processing class instead.
def segment_at_point(image_path, point_coords):
    """Segment the object at a specific point; point_coords is [[x, y], ...]."""
    image = cv2.imread(image_path)
    results = model(image, points=point_coords, labels=[1] * len(point_coords),
                     retina_masks=True, verbose=False)
    if results[0].masks is None:
        return None
    return results[0].masks.data[0].cpu().numpy().astype(bool)


# Usage
mask = segment_at_point("scene.jpg", [[450, 300]])
```
Mobile-SAM
Mobile-SAM is a lightweight version of SAM designed for edge and mobile deployment. It uses knowledge distillation to replace SAM's ViT-H image encoder with a much smaller TinyViT encoder while keeping SAM's prompt encoder and mask decoder. This makes it practical on CPUs and on embedded GPUs such as the NVIDIA Jetson family.
Architecture
```text
┌──────────────────┐     ┌──────────────────────────┐
│     TinyViT      │────▶│  Prompt Encoder + Mask   │
│  Image Encoder   │     │  Decoder (from SAM,      │
│  (~5M params)    │     │  ~4M params)             │
└──────────────────┘     └──────────────────────────┘

Total: ~10M params (vs ~640M for SAM ViT-H)
Encoder distilled from SAM's ViT-H image embeddings
```
Mobile-SAM uses what its authors call decoupled distillation:
1. Train the TinyViT encoder to reproduce the image embeddings produced by SAM's ViT-H encoder (a per-image regression; no mask labels are needed).
2. Reuse SAM's original mask decoder. Because the new embeddings approximate the old ones, the decoder keeps working (fine-tuning it afterwards is optional).
Performance Comparison
| Device | SAM | Fast-SAM | Mobile-SAM |
|---|---|---|---|
| NVIDIA A100 | 1.5 img/s | 75 img/s | 15 img/s |
| RTX 3080 | 1.2 img/s | 55 img/s | 12 img/s |
| Jetson Nano | N/A (OOM) | 3 img/s | 2 img/s |
| Intel i7 CPU | 0.02 img/s | 1.2 img/s | 0.2 img/s |
| iPhone 15 (CoreML) | N/A | N/A | 4 img/s |
| Model size | 2.5 GB | 170 MB | 100 MB |
| mIoU (COCO) | 0.83 | 0.75 | 0.78 |
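Mobile-SAM hits a sweet spot between accuracy and speed for embedded robotics. Treat the figures above as indicative. Throughput depends heavily on hardware, input resolution, and whether you run a single prompt or segment-everything mode. The Mobile-SAM authors report their model as both smaller and faster than Fast-SAM in their own comparison.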
Code Example
```python
import cv2
import numpy as np
import torch
from mobile_sam import SamPredictor, sam_model_registry

# Build Mobile-SAM model
sam_checkpoint = "mobile_sam.pt"
model_type = "vit_t"  # TinyViT encoder
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
sam.eval()
predictor = SamPredictor(sam)


def segment_object(image_path, point_coords, point_labels=None):
    """
    Segment an object using Mobile-SAM with point prompt(s).

    Args:
        image_path: Path to input image
        point_coords: [[x, y], ...] coordinates of prompt point(s)
        point_labels: 1 for foreground, 0 for background (default: all foreground)
    """
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image_rgb)
    if point_labels is None:
        point_labels = np.ones(len(point_coords), dtype=np.int32)
    masks, scores, logits = predictor.predict(
        point_coords=np.array(point_coords),
        point_labels=np.asarray(point_labels),
        multimask_output=True,
    )
    best_idx = scores.argmax()
    return masks[best_idx], scores[best_idx]


# Example: Segment a cup on a table
mask, score = segment_object("kitchen_scene.jpg", point_coords=[[450, 300]])
print(f"Segmentation confidence: {score:.3f}")
print(f"Mask pixels: {mask.sum()}")
```
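```python
# Jetson deployment with TensorRT (sketch).
# Step 1, on the development machine: export the TinyViT image encoder to ONNX.
# (Export details may need adjusting for your PyTorch/ONNX versions; the prompt
# encoder and mask decoder can stay in PyTorch or be exported separately.)
import torch
from mobile_sam import sam_model_registry

sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt").eval()
dummy_image = torch.randn(1, 3, 1024, 1024)  # normalized, padded 1024x1024 input
torch.onnx.export(
    sam.image_encoder, dummy_image, "mobile_sam_encoder.onnx",
    input_names=["image"], output_names=["image_embeddings"], opset_version=14,
)

# Step 2, on the Jetson (engines are specific to the GPU and TensorRT version):
#   trtexec --onnx=mobile_sam_encoder.onnx --saveEngine=mobile_sam.engine --fp16
```
```python
# Step 3, on the Jetson: load the serialized engine
import tensorrt as trt


def load_mobile_sam_engine(engine_path="mobile_sam.engine"):
    """Deserialize a TensorRT engine built from the exported encoder."""
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError(f"Failed to load TensorRT engine from {engine_path}")
    # Run inference: create an execution context with engine.create_execution_context(),
    # allocate device buffers, then execute (standard TensorRT inference code)
    print("Mobile-SAM image encoder loaded with TensorRT")
    return engine
```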
SAM in ROS Integration
This section demonstrates how to create a ROS2 node that integrates SAM for online segmentation in a robotics pipeline. The node subscribes to a camera image topic, takes point prompts, and publishes the resulting segmentation mask. In the code below, prompts are set through the node's set_prompt() method. Exposing it as the /sam/segment service shown in the diagram requires defining a custom service type (SegmentRequest in the diagram).
ROS2 SAM Node Architecture
```text
┌─────────────────────────────────────────────────────────┐
│                      ROS2 SAM Node                      │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  Subscribers:                Services:                  │
│  ┌──────────────┐            ┌──────────────────┐       │
│  │ /camera/image│            │ /sam/segment     │       │
│  │ _raw         │            │ (SegmentRequest) │       │
│  └──────┬───────┘            └────────┬─────────┘       │
│         │                             │                 │
│         ▼                             ▼                 │
│  ┌──────────────────────────────────────────┐           │
│  │              SAM Segmenter               │           │
│  │    (processes image + prompt → mask)     │           │
│  └────────────────────┬─────────────────────┘           │
│                       │                                 │
│  Publishers:          ▼                                 │
│  ┌──────────────────────────────────────────┐           │
│  │ /sam/mask          (sensor_msgs/Image)   │           │
│  │ /sam/mask_overlay  (sensor_msgs/Image)   │           │
│  └──────────────────────────────────────────┘           │
└─────────────────────────────────────────────────────────┘
```
Full ROS2 Python Node
```python
#!/usr/bin/env python3
"""
ROS2 SAM Segmentation Node

Subscribes to camera images, takes point prompts via set_prompt(),
and publishes segmentation masks.

Launch with:
    ros2 run perception sam_node --ros-args \
        -p model_type:=vit_h \
        -p checkpoint:=/path/to/sam_vit_h.pth
"""
from typing import Optional

import numpy as np
import rclpy
import torch
from cv_bridge import CvBridge
from rclpy.node import Node
from sensor_msgs.msg import Image


class SAMNode(Node):
    """ROS2 node wrapping SAM for online segmentation."""

    def __init__(self):
        super().__init__('sam_segmentation_node')

        # Declare parameters
        self.declare_parameter('model_type', 'vit_h')
        self.declare_parameter('checkpoint', 'sam_vit_h.pth')
        self.declare_parameter('device', 'cuda')
        self.declare_parameter('image_topic', '/camera/image_raw')
        self.declare_parameter('confidence_threshold', 0.5)

        # Get parameters
        model_type = self.get_parameter('model_type').value
        checkpoint = self.get_parameter('checkpoint').value
        device = self.get_parameter('device').value
        image_topic = self.get_parameter('image_topic').value
        self.conf_threshold = self.get_parameter('confidence_threshold').value

        # Initialize CV bridge and image state
        self.bridge = CvBridge()
        self.current_image = None
        self.current_image_cv = None
        self.image_is_new = False  # True until the latest image has been encoded

        # Load SAM model
        self.get_logger().info(f"Loading SAM model: {model_type}")
        from segment_anything import SamPredictor, sam_model_registry
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        if device == "cuda" and torch.cuda.is_available():
            sam = sam.cuda()
        self.predictor = SamPredictor(sam)
        self.get_logger().info("SAM model loaded successfully")

        # Subscribers
        self.image_sub = self.create_subscription(
            Image, image_topic, self.image_callback, 10
        )

        # Publishers
        self.mask_pub = self.create_publisher(Image, '/sam/mask', 10)
        self.overlay_pub = self.create_publisher(Image, '/sam/mask_overlay', 10)

        # Timer for periodic segmentation. ~30 Hz is an upper bound: the real rate
        # is limited by image-encoder latency (ViT-H alone needs ~0.15 s on an A100).
        self.create_timer(0.033, self.segment_callback)

        # Stored prompt
        self.prompt_points = None
        self.prompt_labels = None

        self.get_logger().info(f"SAM node started. Listening on: {image_topic}")
        self.get_logger().info("Set a prompt with set_prompt() to start segmenting")

    def image_callback(self, msg: Image):
        """Store latest camera image and convert to numpy."""
        try:
            self.current_image_cv = self.bridge.imgmsg_to_cv2(msg, desired_encoding='rgb8')
            self.current_image = msg
            self.image_is_new = True
        except Exception as e:
            self.get_logger().error(f"Image conversion failed: {e}")

    def set_prompt(self, points: list, labels: list):
        """
        Set segmentation prompt.

        Args:
            points: List of (x, y) tuples in image coordinates
            labels: List of labels (1=foreground, 0=background)
        """
        self.prompt_points = np.array(points, dtype=np.float32)
        self.prompt_labels = np.array(labels, dtype=np.int32)
        self.get_logger().info(f"Prompt set: {len(points)} points, labels={labels}")

    def segment(self) -> Optional[tuple]:
        """
        Run SAM segmentation with current image and prompt.

        Returns:
            Tuple of (masks, logits, scores), or None if no image or prompt yet
        """
        if self.current_image_cv is None:
            self.get_logger().warn("No image received yet", throttle_duration_sec=5.0)
            return None
        if self.prompt_points is None:
            return None

        # Re-run the expensive image encoder only when a new image has arrived
        if self.image_is_new:
            self.predictor.set_image(self.current_image_cv)
            self.image_is_new = False

        # Run prediction (cheap: prompt encoder + mask decoder only)
        masks, scores, logits = self.predictor.predict(
            point_coords=self.prompt_points,
            point_labels=self.prompt_labels,
            multimask_output=True,
        )
        self.get_logger().debug(
            f"Segmentation complete: {len(masks)} masks, best score={scores.max():.3f}"
        )
        return masks, logits, scores

    def segment_callback(self):
        """Timer callback: segment and publish if a prompt is set."""
        result = self.segment()
        if result is None:
            return
        masks, logits, scores = result
        best_idx = scores.argmax()
        if scores[best_idx] < self.conf_threshold:
            return  # skip low-confidence masks
        best_mask = masks[best_idx]

        # Publish best mask as a binary image
        mask_msg = self.bridge.cv2_to_imgmsg(
            best_mask.astype(np.uint8) * 255, encoding='mono8'
        )
        mask_msg.header = self.current_image.header
        self.mask_pub.publish(mask_msg)

        # Publish overlay
        overlay = self.current_image_cv.copy()
        overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
        overlay_msg = self.bridge.cv2_to_imgmsg(overlay.astype(np.uint8), encoding='rgb8')
        overlay_msg.header = self.current_image.header
        self.overlay_pub.publish(overlay_msg)
```
def destroy_node(self):
self.get_logger().info("Shutting down SAM node")
def main(args=None):
rclpy.init(args=args)
node = SAMNode()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
Launch File
# launch/sam_launch.py
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
    return LaunchDescription([
        Node(
            package='perception',
            executable='sam_node',
            name='sam_segmentation',
            parameters=[{
                'model_type': 'vit_h',
                'checkpoint': '/models/sam_vit_h.pth',
                'device': 'cuda',
                'image_topic': '/camera/color/image_raw',
                'confidence_threshold': 0.5,
            }],
            remappings=[],
        ),
    ])
Testing the Node
# Terminal 1: Launch SAM node
ros2 launch perception sam_launch.py
# Terminal 2: Provide a test image on the topic set in sam_launch.py
ros2 topic pub --once /camera/color/image_raw sensor_msgs/Image \
--yaml-file test_image.yaml
# Terminal 3: Call segmentation service with a point prompt
# Using a simple Python script or service call:
python3 -c "
import rclpy
from rclpy.node import Node
from geometry_msgs.msg import Point
rclpy.init()
node = Node('test_client')
# ... call service with point at (450, 300)
"
# Terminal 4: View results
ros2 run rqt_image_view rqt_image_view /sam/mask_overlay
Evaluation Metrics for Segmentation
When deploying SAM or its variants on custom robotics datasets, you need quantitative metrics to evaluate performance. Here are the key metrics and how to compute them.
Core Metrics
mIoU (mean Intersection over Union)
The most widely used segmentation metric. For each object, it measures the overlap between predicted and ground truth masks.
mIoU = (1/N) * Σ_i (|predicted_i ∩ groundtruth_i| / |predicted_i ∪ groundtruth_i|)
Visualization:
Predicted mask: Ground truth: Intersection:
┌──────────┐ ┌──────────┐ ┌──────────┐
│ ████████ │ │ ████████│ │ ████████│
│ ████████ │ │ ████████│ │ ████████│
│ ████████ │ │ ████████│ │ ████████│
└──────────┘ └──────────┘ └──────────┘
IoU = area_of_intersection / area_of_union
mIoU = mean across all classes/objects
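For semantic label maps, where every pixel carries a class ID, mIoU is usually computed per class and then averaged. The sketch below does this with a confusion matrix in NumPy. The function name `mean_iou_from_labels` is hypothetical, and skipping classes that appear in neither the prediction nor the ground truth is an assumption (benchmarks differ on this convention).

```python
import numpy as np


def mean_iou_from_labels(pred: np.ndarray, gt: np.ndarray, num_classes: int) -> float:
    """Mean IoU over classes for integer label maps of equal shape.

    Assumption: classes absent from both pred and gt are skipped.
    """
    # Confusion matrix: rows = ground truth class, columns = predicted class
    idx = gt.astype(np.int64).ravel() * num_classes + pred.astype(np.int64).ravel()
    conf = np.bincount(idx, minlength=num_classes * num_classes)
    conf = conf.reshape(num_classes, num_classes)
    intersection = np.diag(conf)
    # Union per class = GT pixels + predicted pixels - intersection
    union = conf.sum(axis=1) + conf.sum(axis=0) - intersection
    present = union > 0
    return float(np.mean(intersection[present] / union[present]))
```

On a 2×4 toy map where one class-1 pixel is predicted as class 0, class 0 scores 4/5 and class 1 scores 3/4, so the mIoU is 0.775.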
Dice Coefficient (F1 Score for Masks)
Measures the harmonic mean of precision and recall. Often used when class imbalance is significant (small objects in large images).
Dice = 2 * |predicted ∩ groundtruth| / (|predicted| + |groundtruth|)
Range: [0, 1] where 1 = perfect overlap
Note: Dice = 2 * IoU / (IoU + 1), so Dice ≥ IoU for every mask pair
and ranks individual predictions identically; averaged over a dataset,
mean Dice and mIoU can still rank two models differently.
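The identity can be checked numerically. This sketch draws random binary masks and confirms both the closed form and that Dice never falls below IoU. `iou_score` and `dice_score` are illustrative helpers, and the masks are assumed non-empty so neither ratio divides by zero.

```python
import numpy as np


def iou_score(a, b) -> float:
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    return float(np.logical_and(a, b).sum() / np.logical_or(a, b).sum())


def dice_score(a, b) -> float:
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    return float(2 * np.logical_and(a, b).sum() / (a.sum() + b.sum()))


rng = np.random.default_rng(0)
for _ in range(5):
    # Random masks with different foreground rates, so IoU varies per trial
    pred = rng.random((64, 64)) < rng.uniform(0.2, 0.8)
    gt = rng.random((64, 64)) < rng.uniform(0.2, 0.8)
    j = iou_score(pred, gt)
    # Dice is a fixed increasing function of IoU and never smaller than it
    assert np.isclose(dice_score(pred, gt), 2 * j / (j + 1))
    assert dice_score(pred, gt) >= j
```

For example, two 1×4 masks sharing one of three covered pixels have IoU = 1/3 and Dice = 2·(1/3)/(4/3) = 0.5.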
Boundary F1 Score
Evaluates how well the predicted boundary aligns with the ground truth boundary. This is critical for robotics where precise object edges matter for grasping.
Boundary F1:
1. Extract the predicted boundary (mask minus its erosion)
2. Extract the ground truth boundary the same way
3. Dilate each boundary by the pixel tolerance d
4. Compute precision and recall of boundary pixels:
- Precision = fraction of predicted boundary pixels within d of the GT boundary
- Recall = fraction of GT boundary pixels within d of the predicted boundary
5. F1 = 2 * Precision * Recall / (Precision + Recall)
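A brute-force reference is handy for sanity-checking a faster implementation on small masks. The sketch below follows the same precision/recall idea but replaces dilation with explicit pairwise Euclidean distances between boundary pixels, so the tolerance is a true radius (iterated dilation with a 4-connected cross measures city-block distance instead). It is O(N·M) in the number of boundary pixels and only meant for small test masks; `boundary_pixels` and `boundary_f1_bruteforce` are hypothetical names.

```python
import numpy as np


def boundary_pixels(mask: np.ndarray) -> np.ndarray:
    """(row, col) coordinates of mask pixels with a 4-neighbour outside the mask."""
    m = np.asarray(mask, dtype=bool)
    # Pad with False so pixels on the image edge count as boundary
    padded = np.pad(m, 1, constant_values=False)
    # A pixel is interior if it and its 4 neighbours are all inside the mask
    interior = (
        padded[1:-1, 1:-1] & padded[:-2, 1:-1] & padded[2:, 1:-1]
        & padded[1:-1, :-2] & padded[1:-1, 2:]
    )
    return np.argwhere(m & ~interior)


def boundary_f1_bruteforce(mask_pred, mask_gt, tolerance: float = 2.0) -> float:
    """Boundary F1 with a Euclidean pixel tolerance."""
    bp = boundary_pixels(mask_pred)
    bg = boundary_pixels(mask_gt)
    if len(bp) == 0 and len(bg) == 0:
        return 1.0
    if len(bp) == 0 or len(bg) == 0:
        return 0.0
    # Distance between every predicted and every GT boundary pixel
    d = np.linalg.norm(bp[:, None, :] - bg[None, :, :], axis=-1)
    precision = np.mean(d.min(axis=1) <= tolerance)  # predicted pixels near GT
    recall = np.mean(d.min(axis=0) <= tolerance)     # GT pixels near prediction
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))
```

A square shifted by one pixel scores 1.0 at a 2 px tolerance, while a far-away square scores 0.0.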
Computing Metrics
import numpy as np
from scipy import ndimage
def compute_iou(mask_pred: np.ndarray, mask_gt: np.ndarray) -> float:
    """Compute Intersection over Union between two binary masks."""
    intersection = np.logical_and(mask_pred, mask_gt).sum()
    union = np.logical_or(mask_pred, mask_gt).sum()
    if union == 0:
        return 1.0  # Both empty = perfect agreement (matches compute_dice)
    return float(intersection / union)
def compute_dice(mask_pred: np.ndarray, mask_gt: np.ndarray) -> float:
    """Compute Dice coefficient between two binary masks."""
    # Cast to bool so 0/255 masks are counted as pixels, not pixel values
    mask_pred = mask_pred.astype(bool)
    mask_gt = mask_gt.astype(bool)
    intersection = np.logical_and(mask_pred, mask_gt).sum()
    total = mask_pred.sum() + mask_gt.sum()
    if total == 0:
        return 1.0  # Both empty = perfect agreement
    return float(2 * intersection / total)
def compute_boundary_f1(
    mask_pred: np.ndarray,
    mask_gt: np.ndarray,
    tolerance: int = 2
) -> float:
    """
    Compute Boundary F1 score.
    Args:
        mask_pred: Predicted binary mask
        mask_gt: Ground truth binary mask
        tolerance: Tolerance in pixels for boundary matching
    """
    mask_pred = mask_pred.astype(bool)
    mask_gt = mask_gt.astype(bool)
    # Extract boundaries: mask pixels removed by one erosion step
    pred_boundary = mask_pred & ~ndimage.binary_erosion(mask_pred)
    gt_boundary = mask_gt & ~ndimage.binary_erosion(mask_gt)
    if not pred_boundary.any() and not gt_boundary.any():
        return 1.0
    # Dilate boundaries by tolerance (4-connected steps = city-block distance).
    # binary_dilation with iterations < 1 repeats until nothing changes,
    # so tolerance 0 must skip dilation entirely.
    if tolerance > 0:
        struct = ndimage.generate_binary_structure(2, 1)
        pred_dilated = ndimage.binary_dilation(
            pred_boundary, structure=struct, iterations=tolerance
        )
        gt_dilated = ndimage.binary_dilation(
            gt_boundary, structure=struct, iterations=tolerance
        )
    else:
        pred_dilated, gt_dilated = pred_boundary, gt_boundary
    # Precision: predicted boundary pixels within tolerance of the GT boundary
    n_pred = pred_boundary.sum()
    precision = (pred_boundary & gt_dilated).sum() / n_pred if n_pred > 0 else 0.0
    # Recall: GT boundary pixels within tolerance of the predicted boundary
    n_gt = gt_boundary.sum()
    recall = (gt_boundary & pred_dilated).sum() / n_gt if n_gt > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))
def evaluate_sam_on_dataset(
    predictor,
    dataset: list,
    prompt_type: str = "point"
) -> dict:
    """
    Evaluate SAM on a custom dataset.
    Args:
        predictor: SamPredictor instance
        dataset: List of dicts with keys 'image' (HxWx3 RGB uint8) and
            'mask_gt', plus 'prompt_points' and 'prompt_labels' for point
            prompts or 'prompt_box' (XYXY array) for box prompts
        prompt_type: Type of prompt ('point' or 'box')
    Returns:
        Dictionary of metrics
    """
    if prompt_type not in ("point", "box"):
        raise ValueError(f"Unsupported prompt_type: {prompt_type!r}")
    ious = []
    dices = []
    boundary_f1s = []
    for sample in dataset:
        image = sample['image']
        gt_mask = sample['mask_gt'].astype(bool)
        # Encode the image once; the prompt below reuses the embedding
        predictor.set_image(image)
        # Generate prediction based on prompt type
        if prompt_type == "point":
            masks, scores, _ = predictor.predict(
                point_coords=sample['prompt_points'],
                point_labels=sample['prompt_labels'],
                multimask_output=True
            )
        else:  # "box"
            masks, scores, _ = predictor.predict(
                box=sample['prompt_box'],
                multimask_output=True
            )
        # Keep the mask with the highest predicted IoU score
        pred_mask = masks[scores.argmax()]
        # Compute metrics
        ious.append(compute_iou(pred_mask, gt_mask))
        dices.append(compute_dice(pred_mask, gt_mask))
        boundary_f1s.append(compute_boundary_f1(pred_mask, gt_mask))
    results = {
        'mIoU': np.mean(ious),
        'std_IoU': np.std(ious),
        'Dice': np.mean(dices),
        'std_Dice': np.std(dices),
        'Boundary_F1': np.mean(boundary_f1s),
        'std_BF1': np.std(boundary_f1s),
        'num_samples': len(dataset),
    }
    print("=" * 50)
    print("SAM Evaluation Results")
    print("=" * 50)
    print(f"  Samples:     {results['num_samples']}")
    print(f"  mIoU:        {results['mIoU']:.4f} "
          f"± {results['std_IoU']:.4f}")
    print(f"  Dice:        {results['Dice']:.4f} "
          f"± {results['std_Dice']:.4f}")
    print(f"  Boundary F1: {results['Boundary_F1']:.4f} "
          f"± {results['std_BF1']:.4f}")
    print("=" * 50)
    return results
Practical Tips
| Tip | Details |
|---|---|
| Use Boundary F1 for grasping | Grasp planners need precise edges, not just approximate regions |
| Evaluate per-object-class | SAM may work great on cups but poorly on thin wires |
| Test with realistic prompts | Simulate real robot click noise (±2px jitter on points) |
| Measure latency too | High accuracy does not help if inference takes >200ms inside a real-time control loop |
| Create a holdout set | Split your custom dataset: 70% fine-tuning, 15% validation, 15% test |
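Three of the tips above translate directly into small helpers. The sketch below is one way to write them under stated assumptions: `jitter_points` adds uniform ±2 px noise to (x, y) SAM point prompts and clips them to the image, `split_dataset` does a seeded 70/15/15 shuffle-split, and `measure_latency_ms` reports median and 95th-percentile wall-clock time. All three names are hypothetical, uniform noise is an assumption (real click error may be closer to Gaussian), and when timing GPU inference `fn` should synchronize (for example with `torch.cuda.synchronize()`), otherwise the timings may cover only the kernel launches.

```python
import random
import time

import numpy as np


def jitter_points(points, image_shape, max_px: float = 2.0, rng=None):
    """Add uniform +/- max_px noise to (x, y) prompt points, clipped to the image."""
    rng = np.random.default_rng() if rng is None else rng
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    noisy = pts + rng.uniform(-max_px, max_px, size=pts.shape)
    h, w = image_shape[:2]
    # SAM point prompts are (x, y): x indexes columns (width), y indexes rows
    noisy[:, 0] = np.clip(noisy[:, 0], 0, w - 1)
    noisy[:, 1] = np.clip(noisy[:, 1], 0, h - 1)
    return noisy.astype(np.float32)


def split_dataset(samples, fractions=(0.70, 0.15, 0.15), seed: int = 0):
    """Shuffle once with a fixed seed and split into fine-tune/val/test lists."""
    items = list(samples)
    random.Random(seed).shuffle(items)
    n_train = round(len(items) * fractions[0])
    n_val = round(len(items) * fractions[1])
    # Whatever remains after train and val becomes the test split
    return (items[:n_train],
            items[n_train:n_train + n_val],
            items[n_train + n_val:])


def measure_latency_ms(fn, n_warmup: int = 3, n_runs: int = 20) -> dict:
    """Time fn() with time.perf_counter; report median and p95 latency in ms."""
    for _ in range(n_warmup):
        fn()  # warm-up runs (e.g., lazy initialization) are not timed
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1000.0)
    return {"median_ms": float(np.median(times)),
            "p95_ms": float(np.percentile(times, 95))}
```

Reporting the 95th percentile alongside the median matters for control loops, since an occasional slow frame can break a deadline even when the typical frame is fast.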