段一切割模型(SAM)¶
Segment Anything Model(SAM) 是由 Meta AI 开发的基础图像分割模型。它能够在零样本条件下分割图像中的任意目标——无需额外训练。
SAM 为何重要¶
在机器人领域,感知的核心是理解环境。分割回答的是"哪些像素属于哪个物体"这一基本问题:
- 抓取:精确定位目标物体的边界
- 导航:区分障碍物、地面和墙壁
- 操作:理解物体形状以规划抓取姿态
- 场景理解:构建环境的语义地图
传统分割模型只能处理固定的类别(如 COCO 的 80 个类别)。SAM 打破了这一限制——只需点击一下就能分割*任意*物体。
SAM 架构¶
SAM 由三个核心组件组成:
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ 图像 │────▶│ 图像 │────▶│ 掩码 │
│ 编码器 │ │ 嵌入 │ │ 解码器 │
│ (ViT-based) │ │ (可缓存) │ │ (轻量级) │
└──────────────┘ └──────────────┘ └──────────────┘
▲
│
┌──────────────┐
│ 提示 │
│ 编码器 │
│(点/框/文本/掩码)│
└──────────────┘
图像编码器¶
- 基于 Vision Transformer(ViT-H/14)
- 处理一次图像并生成嵌入向量
- 最重量级的组件(约 6 亿参数)
- 对同一张图像的多次提示可复用嵌入
提示编码器¶
支持多种提示方式:
| 提示类型 | 说明 | 使用场景 |
|---|---|---|
| 点 | 点击物体上的位置 | 交互式分割 |
| 边界框 | 在物体周围画矩形 | 目标检测集成 |
| 粗略掩码 | 粗略的分割提示 | 掩码细化 |
| 文本 | 自然语言描述 | 基于 CLIP(实验性) |
掩码解码器¶
- 轻量级 Transformer 解码器(约 200 万参数)
- 实时运行(每个提示约 50ms)
- 输出一个或多个掩码及置信度分数
- 歧义感知:当提示存在歧义时(如点击轮胎还是整辆车),返回多个掩码
SAM v1(2023)¶
核心特性¶
- 零样本迁移:可处理从未训练过的物体和领域
- 发布 3 个数据集:
- SA-1B:1100 万张图像、11 亿个掩码(史上最大的分割数据集)
- SA-V:5.3 万段视频,带密集标注
- SA-M:2000 张医学图像
快速上手 SAM v1¶
# 安装:pip install segment-anything
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
# 加载模型(需先下载权重文件)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
# 加载图像
image = cv2.imread("robot_workspace.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)
# 用点提示 (x=500, y=375) — label 1 = 前景点
input_point = np.array([[500, 375]])
input_label = np.array([1])
# 预测
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True, # 返回 3 个不同粒度的掩码
)
# masks[0] 是最优掩码
print(f"掩码形状: {masks[0].shape}") # (H, W) 布尔值
print(f"置信度分数: {scores}") # [0.99, 0.95, 0.88]
使用边界框提示¶
# 用边界框 [x1, y1, x2, y2] 提示
input_box = np.array([100, 100, 400, 400])
masks, scores, logits = predictor.predict(
box=input_box,
multimask_output=False, # 边界框提示返回单个掩码
)
组合点和框提示¶
# 用额外的点提示来细化边界框结果
input_box = np.array([100, 100, 400, 400])
input_point = np.array([[250, 250]]) # 框内的一个点
input_label = np.array([1]) # 前景点
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
SAM 2(2024)¶
SAM 2 将模型从图像扩展到**视频**,实现了跨帧的实时物体分割与跟踪。
关键改进¶
| 特性 | SAM v1 | SAM 2 |
|---|---|---|
| 输入 | 仅图像 | 图像 + 视频 |
| 时序跟踪 | ❌ | ✅ 基于记忆机制 |
| 架构 | ViT 编码器 | Hiera(分层结构) |
| 速度 | ~50ms/帧 | ~44ms/帧 |
| 交互式视频 | ❌ | ✅ 任意帧点/框提示 |
| 遮挡处理 | ❌ | ✅ 记忆机制 |
SAM 2 架构¶
┌────────────────────────────────────────────────┐
│ SAM 2 流水线 │
│ │
│ 帧 t: 图像 ──▶ Hiera 编码器 ──▶ 嵌入 │
│ │ │
│ ┌──────────┘ │
│ ▼ │
│ 记忆库 ◀── 前序帧 │
│ (存储历史 │
│ 物体信息) │
│ │ │
│ ▼ │
│ 记忆注意力 │
│ │ │
│ ▼ │
│ 掩码解码器 ──▶ Mask_t │
│ ▲ │
│ 提示编码器 │
│ (点/框/掩码) │
└────────────────────────────────────────────────┘
**记忆库**存储前序帧的嵌入信息,使模型能够在遮挡、形变和外观变化中持续跟踪物体。
快速上手 SAM 2¶
# 安装:pip install sam-2
from sam2.build_sam import build_sam2_video_predictor
import torch
import numpy as np
# 加载模型
predictor = build_sam2_video_predictor(
"sam2_hiera_l.yaml",
"sam2_hiera_large.pt"
)
# 处理视频
video_dir = "path/to/video/frames/" # JPEG 帧所在目录
frame_names = sorted([
f for f in os.listdir(video_dir)
if f.endswith(('.jpg', '.jpeg'))
])
# 初始化状态
inference_state = predictor.init_state(video_dir=video_dir)
# 在第 0 帧添加提示
ann_frame_idx = 0
ann_obj_id = 1
# 点提示:(x, y),label: 1=前景, 0=背景
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# 在视频中传播 — SAM 2 自动跟踪物体
video_segments = {}
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
video_segments[frame_idx] = {
out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(obj_ids)
}
交互式视频分割¶
SAM 2 支持交互式细化:如果跟踪在第 N 帧出现偏移,可以在该帧添加修正提示,传播会自动修正。
# 初始传播后,在第 50 帧进行修正
correction_frame = 50
correction_points = np.array([[220, 360]], dtype=np.float32)
correction_labels = np.array([1], dtype=np.int32)
# 重置并从修正点重新传播
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=correction_frame,
obj_id=ann_obj_id,
points=correction_points,
labels=correction_labels,
)
# 从第 50 帧开始重新传播
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
video_segments[frame_idx] = {
out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(obj_ids)
}
机器人应用¶
1. 物体抓取流水线¶
将 SAM 与深度相机结合用于机器人抓取:
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import pyrealsense2 as rs
# 1. 从 RealSense 捕获 RGB + Depth
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
pipeline.start(config)
frames = pipeline.wait_for_frames()
color_frame = frames.get_color_frame()
depth_frame = frames.get_depth_frame()
color_image = np.asanyarray(color_frame.get_data())
depth_image = np.asanyarray(depth_frame.get_data())
# 2. 用 SAM 分割目标物体
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(color_image)
# 点击目标物体
masks, scores, _ = predictor.predict(
point_coords=np.array([[320, 240]]),
point_labels=np.array([1]),
multimask_output=False,
)
mask = masks[0] # 布尔掩码 (H, W)
# 3. 提取物体的 3D 点云
depth_intrinsics = depth_frame.profile.as_video_stream_profile().intrinsics
object_points = []
for v in range(mask.shape[0]):
for u in range(mask.shape[1]):
if mask[v, u]:
depth = depth_image[v, u] * 0.001 # 毫米转米
if 0.1 < depth < 2.0: # 有效范围
point = rs.rs2_deproject_pixel_to_point(
depth_intrinsics, [u, v], depth
)
object_points.append(point)
object_points = np.array(object_points)
centroid = object_points.mean(axis=0)
print(f"物体质心: {centroid}") # 用于抓取规划
2. 导航场景理解¶
# 分割场景中的所有物体
from segment_anything import SamAutomaticMaskGenerator
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32, # 点提示密度
pred_iou_thresh=0.88, # 掩码质量阈值
stability_score_thresh=0.95, # 掩码稳定性阈值
min_mask_region_area=100, # 过滤微小区域
)
masks = mask_generator.generate(image_rgb)
# 每个掩码包含:segmentation, area, bbox, predicted_iou, stability_score
# 按面积排序(从大到小)
masks_sorted = sorted(masks, key=lambda x: x['area'], reverse=True)
for i, mask_data in enumerate(masks_sorted[:5]):
print(f"物体 {i}: 面积={mask_data['area']}px, "
f"bbox={mask_data['bbox']}, "
f"iou={mask_data['predicted_iou']:.3f}")
3. 结合 CLIP 实现语言引导分割¶
import clip
import torch
from segment_anything import SamPredictor
def segment_by_text(image, text_query, sam_predictor, clip_model, clip_preprocess):
"""用文本描述分割物体"""
h, w = image.shape[:2]
# 1. 生成网格点提示
grid_points = []
for y in range(0, h, 32):
for x in range(0, w, 32):
grid_points.append([x, y])
# 2. 对每个点裁剪 patch,计算 CLIP 相似度
similarities = []
for x, y in grid_points:
x1, y1 = max(0, x-32), max(0, y-32)
x2, y2 = min(w, x+32), min(h, y+32)
patch = image[y1:y2, x1:x2]
patch_input = clip_preprocess(Image.fromarray(patch)).unsqueeze(0)
text_input = clip.tokenize([text_query])
with torch.no_grad():
image_features = clip_model.encode_image(patch_input)
text_features = clip_model.encode_text(text_input)
sim = (image_features @ text_features.T).item()
similarities.append(sim)
# 3. 用最相似的点作为 SAM 提示
best_idx = np.argmax(similarities)
best_point = np.array([grid_points[best_idx]])
sam_predictor.set_image(image)
masks, scores, _ = sam_predictor.predict(
point_coords=best_point,
point_labels=np.array([1]),
multimask_output=False,
)
return masks[0]
进阶:微调 SAM 适配特定领域¶
SAM 虽然支持零样本分割,但微调可以提升特定领域物体的表现(如特定机器人部件、工业零件)。
from segment_anything import sam_model_registry
import torch
# 加载预训练 SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# 冻结图像编码器,只训练提示编码器和掩码解码器
for param in sam.image_encoder.parameters():
param.requires_grad = False
# 自定义数据集
class CustomSegDataset(torch.utils.data.Dataset):
def __init__(self, images, masks, points):
self.images = images # (H, W, 3) 数组列表
self.masks = masks # (H, W) 布尔数组列表
self.points = points # (N, 2) 点提示列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return {
"image": torch.from_numpy(self.images[idx]).permute(2, 0, 1),
"mask": torch.from_numpy(self.masks[idx]).float(),
"points": torch.from_numpy(self.points[idx]).float(),
}
# 用小数据集微调(100-500 张图像即可)
optimizer = torch.optim.Adam(
list(sam.prompt_encoder.parameters()) +
list(sam.mask_decoder.parameters()),
lr=1e-5
)
criterion = torch.nn.BCEWithLogitsLoss()
for epoch in range(10):
for batch in dataloader:
image_embedding = sam.image_encoder(batch["image"])
sparse_embeddings, dense_embeddings = sam.prompt_encoder(
points=batch["points"], masks=None, boxes=None,
)
mask_predictions, _ = sam.mask_decoder(
image_embeddings=image_embedding,
image_pe=sam.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
loss = criterion(mask_predictions.squeeze(), batch["mask"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
模型对比¶
| 模型 | 类型 | 零样本 | 交互式 | 视频 | 速度 |
|---|---|---|---|---|---|
| SAM | 通用 | ✅ | ✅ | ❌ | 50ms |
| SAM 2 | 通用 | ✅ | ✅ | ✅ | 44ms |
| YOLO-SAM | 检测+分割 | ✅ | ❌ | ❌ | 80ms |
| Mask R-CNN | 实例分割 | ❌ | ❌ | ❌ | 100ms |
| DeepLab | 语义分割 | ❌ | ❌ | ❌ | 60ms |
| Grounded-SAM | 语言引导 | ✅ | ✅ | ❌ | 120ms |
SAM 生态变体¶
SAM 的生态已快速扩展:
- Grounded-SAM:结合 Grounding DINO(文本检测)+ SAM,实现语言引导分割
- Fast-SAM:基于 YOLOv8,比 SAM 快 50 倍,精度略低
- Mobile-SAM:轻量版本,适用于边缘设备(约快 10 倍)
- SAM-Med2D:针对医学影像微调
- SAM 2.1:SAM 2 的小幅改进版本,跟踪一致性更好
总结¶
| 主题 | 要点 |
|---|---|
| SAM v1 | 通用图像分割,支持点/框提示 |
| SAM 2 | 扩展至视频,基于记忆机制的跟踪 |
| 机器人应用 | 抓取、导航、场景理解 |
| 微调 | 冻结编码器,用领域数据训练解码器 |
| 生态 | Grounded-SAM、Fast-SAM、Mobile-SAM 等变体 |