跳转至

段一切割模型(SAM)

Segment Anything Model(SAM) 是由 Meta AI 开发的基础图像分割模型。它能够在零样本条件下分割图像中的任意目标——无需额外训练。

SAM 为何重要

在机器人领域,感知的核心是理解环境。分割回答的是"哪些像素属于哪个物体"这一基本问题:

  • 抓取:精确定位目标物体的边界
  • 导航:区分障碍物、地面和墙壁
  • 操作:理解物体形状以规划抓取姿态
  • 场景理解:构建环境的语义地图

传统分割模型只能处理固定的类别(如 COCO 的 80 个类别)。SAM 打破了这一限制——只需点击一下就能分割*任意*物体。

SAM 架构

SAM 由三个核心组件组成:

┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│   图像        │────▶│   图像        │────▶│   掩码        │
│   编码器      │     │   嵌入        │     │   解码器      │
│ (ViT-based)  │     │  (可缓存)     │     │  (轻量级)     │
└──────────────┘     └──────────────┘     └──────────────┘
                                         ┌──────────────┐
                                         │   提示        │
                                         │   编码器      │
                                         │(点/框/文本/掩码)│
                                         └──────────────┘

图像编码器

  • 基于 Vision Transformer(ViT-H/14)
  • 处理一次图像并生成嵌入向量
  • 最重量级的组件(约 6 亿参数)
  • 对同一张图像的多次提示可复用嵌入

提示编码器

支持多种提示方式:

提示类型 说明 使用场景
点击物体上的位置 交互式分割
边界框 在物体周围画矩形 目标检测集成
粗略掩码 粗略的分割提示 掩码细化
文本 自然语言描述 基于 CLIP(实验性)

掩码解码器

  • 轻量级 Transformer 解码器(约 200 万参数)
  • 实时运行(每个提示约 50ms)
  • 输出一个或多个掩码及置信度分数
  • 歧义感知:当提示存在歧义时(如点击轮胎还是整辆车),返回多个掩码

SAM v1(2023)

核心特性

  • 零样本迁移:可处理从未训练过的物体和领域
  • 发布 3 个数据集
  • SA-1B:1100 万张图像、11 亿个掩码(史上最大的分割数据集)
  • SA-V:5.3 万段视频,带密集标注
  • SA-M:2000 张医学图像

快速上手 SAM v1

# 安装:pip install segment-anything
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np

# 加载模型(需先下载权重文件)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# 加载图像
image = cv2.imread("robot_workspace.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)

# 用点提示 (x=500, y=375) — label 1 = 前景点
input_point = np.array([[500, 375]])
input_label = np.array([1])

# 预测
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # 返回 3 个不同粒度的掩码
)

# masks[0] 是最优掩码
print(f"掩码形状: {masks[0].shape}")    # (H, W) 布尔值
print(f"置信度分数: {scores}")           # [0.99, 0.95, 0.88]

使用边界框提示

# 用边界框 [x1, y1, x2, y2] 提示
input_box = np.array([100, 100, 400, 400])

masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False,  # 边界框提示返回单个掩码
)

组合点和框提示

# 用额外的点提示来细化边界框结果
input_box = np.array([100, 100, 400, 400])
input_point = np.array([[250, 250]])  # 框内的一个点
input_label = np.array([1])           # 前景点

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

SAM 2(2024)

SAM 2 将模型从图像扩展到**视频**,实现了跨帧的实时物体分割与跟踪。

关键改进

特性 SAM v1 SAM 2
输入 仅图像 图像 + 视频
时序跟踪 ✅ 基于记忆机制
架构 ViT 编码器 Hiera(分层结构)
速度 ~50ms/帧 ~44ms/帧
交互式视频 ✅ 任意帧点/框提示
遮挡处理 ✅ 记忆机制

SAM 2 架构

┌────────────────────────────────────────────────┐
│                SAM 2 流水线                     │
│                                                │
│  帧 t:  图像 ──▶ Hiera 编码器 ──▶ 嵌入         │
│                              │                  │
│                   ┌──────────┘                  │
│                   ▼                             │
│              记忆库 ◀── 前序帧                   │
│             (存储历史                            │
│              物体信息)                           │
│                   │                             │
│                   ▼                             │
│              记忆注意力                          │
│                   │                             │
│                   ▼                             │
│              掩码解码器 ──▶ Mask_t               │
│                   ▲                             │
│              提示编码器                          │
│          (点/框/掩码)                           │
└────────────────────────────────────────────────┘

**记忆库**存储前序帧的嵌入信息,使模型能够在遮挡、形变和外观变化中持续跟踪物体。

快速上手 SAM 2

# 安装:pip install sam-2
from sam2.build_sam import build_sam2_video_predictor
import torch
import numpy as np

# 加载模型
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml",
    "sam2_hiera_large.pt"
)

# 处理视频
video_dir = "path/to/video/frames/"  # JPEG 帧所在目录
frame_names = sorted([
    f for f in os.listdir(video_dir)
    if f.endswith(('.jpg', '.jpeg'))
])

# 初始化状态
inference_state = predictor.init_state(video_dir=video_dir)

# 在第 0 帧添加提示
ann_frame_idx = 0
ann_obj_id = 1

# 点提示:(x, y),label: 1=前景, 0=背景
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# 在视频中传播 — SAM 2 自动跟踪物体
video_segments = {}
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[frame_idx] = {
        out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(obj_ids)
    }

交互式视频分割

SAM 2 支持交互式细化:如果跟踪在第 N 帧出现偏移,可以在该帧添加修正提示,传播会自动修正。

# 初始传播后,在第 50 帧进行修正
correction_frame = 50
correction_points = np.array([[220, 360]], dtype=np.float32)
correction_labels = np.array([1], dtype=np.int32)

# 重置并从修正点重新传播
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=correction_frame,
    obj_id=ann_obj_id,
    points=correction_points,
    labels=correction_labels,
)

# 从第 50 帧开始重新传播
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[frame_idx] = {
        out_obj_id: (mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(obj_ids)
    }

机器人应用

1. 物体抓取流水线

将 SAM 与深度相机结合用于机器人抓取:

import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import pyrealsense2 as rs

# 1. 从 RealSense 捕获 RGB + Depth
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
pipeline.start(config)

frames = pipeline.wait_for_frames()
color_frame = frames.get_color_frame()
depth_frame = frames.get_depth_frame()

color_image = np.asanyarray(color_frame.get_data())
depth_image = np.asanyarray(depth_frame.get_data())

# 2. 用 SAM 分割目标物体
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(color_image)

# 点击目标物体
masks, scores, _ = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=False,
)

mask = masks[0]  # 布尔掩码 (H, W)

# 3. 提取物体的 3D 点云
depth_intrinsics = depth_frame.profile.as_video_stream_profile().intrinsics
object_points = []

for v in range(mask.shape[0]):
    for u in range(mask.shape[1]):
        if mask[v, u]:
            depth = depth_image[v, u] * 0.001  # 毫米转米
            if 0.1 < depth < 2.0:  # 有效范围
                point = rs.rs2_deproject_pixel_to_point(
                    depth_intrinsics, [u, v], depth
                )
                object_points.append(point)

object_points = np.array(object_points)
centroid = object_points.mean(axis=0)
print(f"物体质心: {centroid}")  # 用于抓取规划

2. 导航场景理解

# 分割场景中的所有物体
from segment_anything import SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,          # 点提示密度
    pred_iou_thresh=0.88,        # 掩码质量阈值
    stability_score_thresh=0.95,  # 掩码稳定性阈值
    min_mask_region_area=100,    # 过滤微小区域
)

masks = mask_generator.generate(image_rgb)

# 每个掩码包含:segmentation, area, bbox, predicted_iou, stability_score
# 按面积排序(从大到小)
masks_sorted = sorted(masks, key=lambda x: x['area'], reverse=True)

for i, mask_data in enumerate(masks_sorted[:5]):
    print(f"物体 {i}: 面积={mask_data['area']}px, "
          f"bbox={mask_data['bbox']}, "
          f"iou={mask_data['predicted_iou']:.3f}")

3. 结合 CLIP 实现语言引导分割

import clip
import torch
from segment_anything import SamPredictor

def segment_by_text(image, text_query, sam_predictor, clip_model, clip_preprocess):
    """用文本描述分割物体"""
    h, w = image.shape[:2]

    # 1. 生成网格点提示
    grid_points = []
    for y in range(0, h, 32):
        for x in range(0, w, 32):
            grid_points.append([x, y])

    # 2. 对每个点裁剪 patch,计算 CLIP 相似度
    similarities = []
    for x, y in grid_points:
        x1, y1 = max(0, x-32), max(0, y-32)
        x2, y2 = min(w, x+32), min(h, y+32)
        patch = image[y1:y2, x1:x2]

        patch_input = clip_preprocess(Image.fromarray(patch)).unsqueeze(0)
        text_input = clip.tokenize([text_query])
        with torch.no_grad():
            image_features = clip_model.encode_image(patch_input)
            text_features = clip_model.encode_text(text_input)
            sim = (image_features @ text_features.T).item()
        similarities.append(sim)

    # 3. 用最相似的点作为 SAM 提示
    best_idx = np.argmax(similarities)
    best_point = np.array([grid_points[best_idx]])

    sam_predictor.set_image(image)
    masks, scores, _ = sam_predictor.predict(
        point_coords=best_point,
        point_labels=np.array([1]),
        multimask_output=False,
    )
    return masks[0]

进阶:微调 SAM 适配特定领域

SAM 虽然支持零样本分割,但微调可以提升特定领域物体的表现(如特定机器人部件、工业零件)。

from segment_anything import sam_model_registry
import torch

# 加载预训练 SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")

# 冻结图像编码器,只训练提示编码器和掩码解码器
for param in sam.image_encoder.parameters():
    param.requires_grad = False

# 自定义数据集
class CustomSegDataset(torch.utils.data.Dataset):
    def __init__(self, images, masks, points):
        self.images = images      # (H, W, 3) 数组列表
        self.masks = masks        # (H, W) 布尔数组列表
        self.points = points      # (N, 2) 点提示列表

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": torch.from_numpy(self.images[idx]).permute(2, 0, 1),
            "mask": torch.from_numpy(self.masks[idx]).float(),
            "points": torch.from_numpy(self.points[idx]).float(),
        }

# 用小数据集微调(100-500 张图像即可)
optimizer = torch.optim.Adam(
    list(sam.prompt_encoder.parameters()) +
    list(sam.mask_decoder.parameters()),
    lr=1e-5
)

criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(10):
    for batch in dataloader:
        image_embedding = sam.image_encoder(batch["image"])
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=batch["points"], masks=None, boxes=None,
        )
        mask_predictions, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        loss = criterion(mask_predictions.squeeze(), batch["mask"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

模型对比

模型 类型 零样本 交互式 视频 速度
SAM 通用 50ms
SAM 2 通用 44ms
YOLO-SAM 检测+分割 80ms
Mask R-CNN 实例分割 100ms
DeepLab 语义分割 60ms
Grounded-SAM 语言引导 120ms

SAM 生态变体

SAM 的生态已快速扩展:

  • Grounded-SAM:结合 Grounding DINO(文本检测)+ SAM,实现语言引导分割
  • Fast-SAM:基于 YOLOv8,比 SAM 快 50 倍,精度略低
  • Mobile-SAM:轻量版本,适用于边缘设备(约快 10 倍)
  • SAM-Med2D:针对医学影像微调
  • SAM 2.1:SAM 2 的小幅改进版本,跟踪一致性更好

总结

主题 要点
SAM v1 通用图像分割,支持点/框提示
SAM 2 扩展至视频,基于记忆机制的跟踪
机器人应用 抓取、导航、场景理解
微调 冻结编码器,用领域数据训练解码器
生态 Grounded-SAM、Fast-SAM、Mobile-SAM 等变体

参考资料