跳转至

视觉目标跟踪 (Visual Object Tracking)

项目类型: 感知 | 难度: ★★☆☆☆ 到 ★★★★★ | 预计时间: 1–3 个周末


1. 项目概述

视觉目标跟踪(Visual Object Tracking, VOT)的任务是在给定第一帧目标初始边界框的条件下,估计目标在视频各帧中的轨迹。与目标检测不同,跟踪必须在一段时间内保持**目标身份**,并处理遮挡、形变和运动模糊等问题。

  ┌────────────────────────────────────────────────────────────┐
  │                   跟踪流程图                                 │
  │                                                             │
  │   第 t-1 帧        第 t 帧          第 t+1 帧                 │
  │   ┌────────┐       ┌────────┐      ┌────────┐               │
  │   │ ▓▓▓▓▓ │──────►│ ▓▓▓▓▓ │─────►│ ▓▓▓▓▓ │               │
  │   │ ▓▓▓▓▓ │       │ ▓▓▓▓▓ │      │ ▓▓▓▓▓ │               │
  │   │ ID: 1  │       │ ID: 1  │      │ ID: 1  │               │
  │   └────────┘       └────────┘      └────────┘               │
  │        │               │               │                     │
  │        └───────────────┴───────────────┘                     │
  │                    跟踪 ID: 1                                 │
  │   ┌──────────────────────────────────────────────┐            │
  │   │  检测 → 预测 → 关联 → 更新                    │            │
  │   └──────────────────────────────────────────────┘            │
  └────────────────────────────────────────────────────────────┘

本项目将实现**三个难度递进的算法层次**:

层次 方法 核心技术 速度 精度
传统方法 KCF 跟踪器 核化相关滤波器 + HOG ~300 FPS ★★★☆☆
中级方法 SORT / DeepSORT 检测 + 卡尔曼滤波 + 匈牙利算法 ~100 FPS ★★★★☆
现代方法 Transformer 跟踪器 注意力机制 + 模板匹配 ~30 FPS ★★★★★

2. 硬件与软件需求

硬件

组件 规格 备注
相机 USB 摄像头或 Pi Camera 最小 640×480,30 FPS
(可选)GPU NVIDIA 显卡,显存 ≥ 4 GB 用于 DeepSORT 和 Transformer 跟踪器
机器人 / 测试场景 任意移动的物体或人物 用于实际测试

软件

版本 用途
Python ≥ 3.8 核心语言
OpenCV ≥ 4.5 图像处理与跟踪
NumPy ≥ 1.20 数值计算
filterpy ≥ 1.4 卡尔曼滤波器实现
scipy ≥ 1.7 线性代数、匈牙利算法
torch ≥ 1.10 深度学习(DeepSORT, Transformer)
torchvision ≥ 0.11 预训练 CNN 模型
ultralytics ≥ 8.0 YOLO 检测(用于 SORT/DeepSORT)
pip install opencv-python numpy filterpy scipy torch torchvision ultralytics

3. 第一层 — 传统方法:KCF 跟踪器

3.1 核心思想

核化相关滤波器(Kernelized Correlation Filter, KCF)将目标跟踪问题建模为**傅里叶域中的岭回归**问题。通过利用图像块的**循环矩阵结构**,KCF 在保持竞争力的精度的同时实现了极高的速度。

核心思想:

  1. 循环矩阵:图像块的平移操作生成一个循环矩阵。循环矩阵的所有特征值由第一行的离散傅里叶变换(DFT)给出——这就是**循环矩阵定理**。

  2. 核技巧:使用核函数(如高斯 RBF)将特征映射到高维空间,实现非线性边界估计,而无需显式计算高维表示。

  3. 快速检测:在傅里叶域中,求解核岭回归退化为逐元素除法——使检测极其快速。

3.2 数学公式

训练阶段:给定从训练帧中提取的一组图像块 \(x_i\),求解岭回归:

\[ \min_{\mathbf{w}} \sum_i \left\| f(\mathbf{x}_i) - y_i \right\|^2 + \lambda \|\mathbf{w}\|^2 \]

其中 \(y_i\) 是高斯形状的回归目标(以目标为中心点的 2D 高斯分布)。

傅里叶域求解(循环矩阵 \(\mathbf{C}(\mathbf{x})\)):

\[ \mathbf{\hat{w}} = \frac{\hat{\mathbf{x}} \odot \hat{\mathbf{y}}^*}{\hat{\mathbf{x}}^* \odot \hat{\mathbf{x}} + \lambda} \]

其中 \(\hat{}\) 表示 DFT,\(\odot\) 是逐元素乘法,\(^*\) 是复共轭。

核映射:使用核 \(\kappa(\mathbf{x}, \mathbf{x}')\),分类器变为:

\[ f(\mathbf{z}) = \mathbf{w}^T \phi(\mathbf{z}) = \sum_i \alpha_i \kappa(\mathbf{x}_i, \mathbf{z}) \]

其中 \(\alpha_i\) 是学习到的权重。

检测:对于新图像块 \(\mathbf{z}\),计算相关响应:

\[ \hat{\mathbf{f}}(\mathbf{z}) = \hat{\mathbf{k}}^{\mathbf{xz}} \odot \hat{\boldsymbol{\alpha}} \]

响应图的峰值指示新的目标中心。

3.3 完整 Python 代码 — KCF 跟踪器

"""
第一层:核化相关滤波器(KCF)跟踪器
====================================
从头实现,使用 HOG 特征 + 高斯核岭回归。
生产环境建议使用 OpenCV 的 cv2.TrackerKCF_create()。

功能:
- HOG(方向梯度直方图)特征提取
- HOG 特征空间中的高斯核
- 利用 FFT 实现循环矩阵理论,加速训练/检测
"""

import numpy as np
import cv2
import time


# ─── HOG 特征提取器 ────────────────────────────────────────────────

def compute_hog_features(gray: np.ndarray, cell_size: int = 4, num_bins: int = 9) -> np.ndarray:
    """
    从灰度图像计算 HOG 特征。

    Parameters
    ----------
    gray : 2D array (H x W)
        灰度图像(已裁剪到 patch 区域)
    cell_size : int
        每个 HOG cell 的像素大小
    num_bins : int
        方向 bin 数量

    Returns
    -------
    features : 1D array
        展平后的 HOG 特征向量
    """
    h, w = gray.shape
    n_cells_x = w // cell_size
    n_cells_y = h // cell_size

    # 1. 计算图像梯度
    gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=1)
    gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=1)
    magnitude = np.sqrt(gx**2 + gy**2)
    orientation = np.arctan2(gy, gx) * (180 / np.pi) % 180  # [0, 180)

    # 2. 计算每个 cell 的直方图
    cell_histograms = np.zeros((n_cells_y, n_cells_x, num_bins), dtype=np.float32)
    bin_width = 180 / num_bins

    for cy in range(n_cells_y):
        for cx in range(n_cells_x):
            y0, y1 = cy * cell_size, (cy + 1) * cell_size
            x0, x1 = cx * cell_size, (cx + 1) * cell_size

            mag_patch = magnitude[y0:y1, x0:x1].flatten()
            ori_patch = orientation[y0:y1, x0:x1].flatten()

            hist = np.zeros(num_bins, dtype=np.float32)
            for m, o in zip(mag_patch, ori_patch):
                bin_idx = int(o / bin_width) % num_bins
                hist[bin_idx] += m

            cell_histograms[cy, cx] = hist

    # 3. 块归一化(L2 范数)
    eps = 1e-5
    features = []
    for by in range(n_cells_y - 1):
        for bx in range(n_cells_x - 1):
            block = cell_histograms[by:by+2, bx:bx+2].flatten()
            norm = np.sqrt(np.sum(block**2) + eps)
            features.extend(block / norm)

    return np.array(features, dtype=np.float32)


def gaussian_response_map(size: tuple, sigma: float = 2.0) -> np.ndarray:
    """
    生成 2D 高斯响应图作为回归目标。

    Parameters
    ----------
    size : (H, W)
        响应图尺寸
    sigma : float
        高斯标准差

    Returns
    -------
    response : 2D array
        高斯形状目标图
    """
    h, w = size
    cy, cx = h // 2, w // 2
    y, x = np.ogrid[:h, :w]
    response = np.exp(-((x - cx)**2 + (y - cy)**2) / (2 * sigma**2))
    return response.astype(np.float32)


def gaussian_kernel(x1: np.ndarray, x2: np.ndarray, sigma: float = 0.5) -> np.ndarray:
    """
    计算高斯核矩阵:所有配对的 k(x1_i, x2_j)。
    利用循环结构提高效率。

    Parameters
    ----------
    x1, x2 : 1D arrays (D,)
        特征向量
    sigma : float
        核带宽

    Returns
    -------
    k : 2D array (N x M)
        核矩阵
    """
    n = len(x1)
    c = np.zeros(n, dtype=np.float32)
    c[0] = np.exp(-np.sum((x1 - x2)**2) / (2 * sigma**2))
    for i in range(1, n):
        shift = i % n
        c[i] = c[0]  # 简化版:实际实现需考虑真实偏移

    hat_c = np.fft.fft(c)
    return hat_c


class KCFTacker:
    """
    使用 HOG 特征的核化相关滤波器跟踪器。
    """

    def __init__(self, patch_size: int = 64, lambda_reg: float = 1e-4,
                 sigma: float = 0.5, interp_factor: float = 0.075):
        self.patch_size = patch_size          # 搜索/训练 patch 大小
        self.lambda_reg = lambda_reg          # 正则化参数
        self.sigma = sigma                    # 高斯核带宽
        self.interp_factor = interp_factor   # 模型更新率

        self.model_alphas = None  # 相关滤波器的 FFT
        self.model_x = None        # HOG 特征的 FFT(训练 patch)
        self.target_pos = None     # 图像坐标中的 (x, y)
        self.target_sz = None      # (宽度, 高度)

    def init(self, frame: np.ndarray, bbox: tuple) -> None:
        """
        用第一帧和边界框初始化跟踪器。

        Parameters
        ----------
        frame : BGR 图像
        bbox : (x, y, w, h) — 左上角坐标 + 尺寸
        """
        x, y, w, h = bbox
        cx, cy = x + w / 2, y + h / 2
        self.target_pos = (float(cx), float(cy))
        self.target_sz = (float(w), float(h))

        # 提取并调整 patch 大小
        patch = self._extract_patch(frame, self.target_pos, self.target_sz, self.patch_size)
        gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

        # 提取 HOG 特征
        self.model_x = compute_hog_features(gray)

        # 生成回归目标(高斯标签)
        y_target = gaussian_response_map((gray.shape[0], gray.shape[1]), sigma=2.0)

        # 训练滤波器:傅里叶域求解 alphas
        x_fft = np.fft.fft2(self.model_x.reshape(gray.shape))
        y_fft = np.fft.fft2(y_target)

        # KCF 解:alpha = y / (x * x_conj + lambda)
        x_power = np.abs(x_fft)**2
        self.model_alphas = np.fft.fft2(y_target) / (x_power + self.lambda_reg)

    def update(self, frame: np.ndarray) -> tuple:
        """
        在当前帧中跟踪目标。

        Returns
        -------
        bbox : (x, y, w, h)
        """
        if self.target_pos is None:
            raise ValueError("跟踪器未初始化。请先调用 init()。")

        # 提取搜索 patch(比目标稍大)
        search_sz = (self.target_sz[0] * 1.5, self.target_sz[1] * 1.5)
        patch = self._extract_patch(frame, self.target_pos, search_sz, self.patch_size)
        gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

        # 提取 HOG 特征
        feat = compute_hog_features(gray)

        # 通过 FFT 计算响应图
        feat_fft = np.fft.fft2(feat.reshape(gray.shape))
        response_fft = feat_fft * self.model_alphas
        response = np.real(np.fft.ifft2(response_fft))

        # 找到峰值(最大响应位置)
        peak_idx = np.unravel_index(np.argmax(response), response.shape)
        dy, dx = peak_idx[0] - response.shape[0] // 2, peak_idx[1] - response.shape[1] // 2

        # 将位移转换到图像坐标
        scale_x = search_sz[0] / self.patch_size
        scale_y = search_sz[1] / self.patch_size
        dx_px = dx * scale_x
        dy_px = dy * scale_y

        # 更新目标位置
        self.target_pos = (
            self.target_pos[0] + dx_px,
            self.target_pos[1] + dy_px
        )

        # 在线模型更新(插值)
        patch_new = self._extract_patch(frame, self.target_pos, self.target_sz, self.patch_size)
        gray_new = cv2.cvtColor(patch_new, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
        feat_new = compute_hog_features(gray_new)

        self.model_x = (1 - self.interp_factor) * self.model_x + self.interp_factor * feat_new
        feat_fft_new = np.fft.fft2(self.model_x.reshape(gray_new.shape))
        x_power_new = np.abs(feat_fft_new)**2
        alpha_new = np.fft.fft2(gaussian_response_map(gray_new.shape)) / (x_power_new + self.lambda_reg)
        self.model_alphas = (1 - self.interp_factor) * self.model_alphas + self.interp_factor * alpha_new

        return self._bbox_from_center(self.target_pos, self.target_sz)

    def _extract_patch(self, frame: np.ndarray, center: tuple,
                       size: tuple, patch_size: int) -> np.ndarray:
        """在 center 处提取并调整大小为 size 的 patch。"""
        x, y = int(center[0] - size[0] / 2), int(center[1] - size[1] / 2)
        patch = frame[y:y+int(size[1]), x:x+int(size[0])]
        if patch.size == 0:
            return np.zeros((patch_size, patch_size, 3), dtype=np.uint8)
        return cv2.resize(patch, (patch_size, patch_size))

    def _bbox_from_center(self, center: tuple, size: tuple) -> tuple:
        """将 (cx, cy, w, h) 中心格式转换为 (x, y, w, h) 左上角格式。"""
        return (int(center[0] - size[0] / 2),
                int(center[1] - size[1] / 2),
                int(size[0]),
                int(size[1]))


def demo_kcf_tracker(video_source: int = 0, initial_bbox: tuple = None):
    """
    在实时摄像头画面上演示 KCF 跟踪器。
    用鼠标绘制边界框来选择目标。
    """
    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        print("[ERROR] 无法打开摄像头")
        return

    # 获取第一帧并让用户选择 bbox
    ret, frame = cap.read()
    if not ret:
        print("[ERROR] 无法读取第一帧")
        return

    if initial_bbox is None:
        print("[INFO] 在图像上绘制边界框,然后按 ENTER 或 SPACE")
        bbox = cv2.selectROI("选择目标", frame, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("选择目标")
    else:
        bbox = initial_bbox

    # 初始化跟踪器
    tracker = KCFTacker()
    tracker.init(frame, bbox)

    print("[INFO] KCF 跟踪已启动。按 'q' 退出。")
    fps_counter, last_time = 0, time.time()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        t0 = time.time()
        tracked_bbox = tracker.update(frame)
        elapsed = time.time() - t0

        fps = 1.0 / elapsed if elapsed > 0 else 0
        fps_counter += 1

        # 绘制边界框
        x, y, w, h = tracked_bbox
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(frame, f"KCF  FPS: {fps:.1f}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

        cv2.imshow("KCF Tracker", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
    print(f"[INFO] 平均 FPS: {fps_counter / max(1, time.time() - last_time):.1f}")


if __name__ == "__main__":
    # 示例:跟踪第一帧中检测到的运动,或手动指定 bbox
    demo_kcf_tracker(video_source=0, initial_bbox=None)

3.4 OpenCV 内置 KCF

OpenCV 提供了现成的 KCF 实现:

import cv2

tracker = cv2.TrackerKCF_create()
# 或更简单的测试:
# tracker = cv2.TrackerMOSSE_create()  # 更快,~1000 FPS

ret, frame = cap.read()
bbox = cv2.selectROI("选择", frame, fromCenter=False)
tracker.init(frame, bbox)

while True:
    ret, frame = cap.read()
    if not ret:
        break
    success, bbox = tracker.update(frame)
    if success:
        x, y, w, h = [int(v) for v in bbox]
        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
    cv2.imshow("KCF", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

4. 第二层 — 中级方法:SORT 与 DeepSORT

4.1 核心思想

SORT(Simple Online and Realtime Tracking)和 DeepSORT(Deep Association Metric)是**基于检测的多目标跟踪器**。它们不是直接搜索目标,而是:

  1. 检测:在每一帧中使用目标检测器(如 YOLO)检测所有目标
  2. 预测:使用卡尔曼滤波器预测每个跟踪轨迹在当前帧中的位置
  3. 关联:使用分配算法(匈牙利算法)将检测结果分配给跟踪轨迹
  4. 更新:用匹配到的检测结果更新卡尔曼滤波器
  ┌─────────────────────────────────────────────────────────────┐
  │                   SORT / DeepSORT 流程图                    │
  │                                                              │
  │   第 t-1 帧 ──► YOLO 检测 ──► ┌─────────────┐               │
  │                               │ 卡尔曼       │               │
  │   第 t 帧 ────► YOLO 检测 ──►│ 滤波器       │               │
  │                               │ 预测         │               │
  │   第 t+1 帧 ──► YOLO 检测 ──►│              │               │
  │                               │ 匈牙利       │               │
  │                               │ 分配         │               │
  │                               └──────┬──────┘               │
  │   跟踪轨迹(ID, bbox, age) ◄────────┘                      │
  └─────────────────────────────────────────────────────────────┘

4.2 数学基础

4.2.1 卡尔曼滤波器

对于 2D 跟踪,我们将每个目标建模为具有**位置 \((x, y)\)、尺寸 \((w, h)\) 和速度 \((\dot{x}, \dot{y})\)**。状态向量为:

\[ \mathbf{x} = [x, y, w, h, \dot{x}, \dot{y}]^T \]

状态转移模型(恒定速度):

\[ \mathbf{x}_t = \mathbf{F} \cdot \mathbf{x}_{t-1} + \mathbf{w}_t \]

其中:

\[ \mathbf{F} = \begin{bmatrix} 1 & 0 & 0 & 0 & dt & 0 \\ 0 & 1 & 0 & 0 & 0 & dt \\ 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix} \]

观测模型(我们只观测位置和尺寸):

\[ \mathbf{z}_t = \mathbf{H} \cdot \mathbf{x}_t + \mathbf{v}_t \]

其中 \(\mathbf{H} = \begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 \end{bmatrix}\)

4.2.2 匈牙利算法

匈牙利算法(Kuhn-Munkres)求解 \(n\) 个检测结果与 \(m\) 个跟踪轨迹之间的**最小代价分配**问题:

\[ \min_{\mathbf{A}} \sum_{i=1}^{n} \sum_{j=1}^{m} A_{ij} \cdot c_{ij} \]

约束:每个检测结果最多分配给一个跟踪轨迹,反之亦然。

SORT 的 IoU 代价矩阵

\[ \text{IoU}(d, t) = \frac{\text{area}(d \cap t)}{\text{area}(d \cup t)} \]

DeepSORT 在代价矩阵中增加了**外观描述符**(余弦距离):

\[ c_{ij} = \lambda \cdot (1 - \text{IoU}_{ij}) + (1 - \lambda) \cdot (1 - s_{ij}) \]

其中 \(s_{ij}\) 是检测 \(d_i\) 和跟踪 \(t_j\) 的深度外观特征之间的余弦相似度。

4.3 完整 Python 代码 — SORT 跟踪器

"""
第二层:SORT 跟踪器(简单在线实时跟踪)
========================================
基于检测的多目标跟踪,使用卡尔曼滤波器 + 匈牙利分配。
使用 YOLO 进行检测(通过 ultralytics),使用 IoU 进行关联。

模块:
1. KalmanFilter — 2D 图像空间的恒定速度模型
2. Track — 管理单个跟踪状态和历史
3. Sort — 管理所有跟踪轨迹并执行分配
"""

import numpy as np
import cv2
from scipy.optimize import linear_sum_assignment
from collections import deque


# ─── 卡尔曼滤波器 ────────────────────────────────────────────────

class KalmanFilter:
    """
    用于 2D 边界框跟踪的恒定速度卡尔曼滤波器。

    状态: [x, y, w, h, vx, vy]
    观测: [x, y, w, h]
    """

    def __init__(self, dt: float = 1.0):
        # 状态转移矩阵(恒定速度模型)
        self.F = np.array([
            [1, 0, 0, 0, dt, 0],
            [0, 1, 0, 0, 0,  dt],
            [0, 0, 1, 0, 0,  0],
            [0, 0, 0, 1, 0,  0],
            [0, 0, 0, 0, 1,  0],
            [0, 0, 0, 0, 0,  1],
        ], dtype=np.float32)

        # 观测矩阵(观测 x, y, w, h)
        self.H = np.array([
            [1, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
        ], dtype=np.float32)

        # 过程噪声协方差
        self.Q = np.eye(6, dtype=np.float32) * 0.01

        # 观测噪声协方差
        self.R = np.eye(4, dtype=np.float32) * 0.1

        # 状态和协方差
        self.x = np.zeros((6, 1), dtype=np.float32)  # 状态估计
        self.P = np.eye(6, dtype=np.float32)          # 误差协方差

    def init(self, measurement: np.ndarray) -> None:
        """用第一次观测初始化滤波器 [x, y, w, h]。"""
        self.x[:4] = measurement.reshape(4, 1)
        self.x[4:] = 0.0  # 速度 = 0

    def predict(self) -> tuple:
        """预测下一状态并返回投影后的观测值。"""
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q

        # 投影到观测空间
        z_pred = self.H @ self.x
        return z_pred.flatten()

    def update(self, measurement: np.ndarray) -> tuple:
        """用新观测更新状态,返回投影后的观测值。"""
        z = measurement.reshape(4, 1).astype(np.float32)
        z_pred = self.H @ self.x

        # 创新(观测残差)
        y = z - z_pred

        # 卡尔曼增益
        S = self.H @ self.P @ self.H.T + self.R
        K = self.P @ self.H.T @ np.linalg.inv(S)

        # 更新状态和协方差
        self.x = self.x + K @ y
        self.P = (np.eye(6) - K @ self.H) @ self.P

        return z_pred.flatten()


# ─── 跟踪轨迹 ────────────────────────────────────────────────────

class Track:
    """单个跟踪轨迹。"""

    _next_id = 1

    def __init__(self, bbox: np.ndarray, frame_id: int):
        self.id = Track._next_id
        Track._next_id += 1

        self.bbox = bbox          # [x, y, w, h]
        self.age = 1              # 首次检测后的帧数
        self.hits = 1             # 总匹配次数
        self.time_since_update = 0

        # 用于预测的卡尔曼滤波器
        self.kf = KalmanFilter()
        self.kf.init(bbox)

        # 轨迹历史(用于可视化)
        self.trace = deque(maxlen=30)

    def predict(self) -> np.ndarray:
        """使用卡尔曼滤波器预测下一边界框。"""
        self.bbox = self.kf.predict()
        self.time_since_update += 1
        return self.bbox

    def update(self, bbox: np.ndarray) -> None:
        """用新检测更新轨迹。"""
        self.bbox = self.kf.update(bbox)
        self.hits += 1
        self.time_since_update = 0

    def state(self) -> str:
        """返回轨迹状态:hits >= 3 为 'confirmed',否则为 'tentative'。"""
        return 'confirmed' if self.hits >= 3 else 'tentative'


# ─── IoU 代价矩阵 ────────────────────────────────────────────────

def compute_iou(boxA: np.ndarray, boxB: np.ndarray) -> float:
    """
    计算两个边界框的交并比(IoU)。

    Parameters
    ----------
    boxA, boxB : [x, y, w, h]

    Returns
    -------
    iou : float,范围 [0, 1]
    """
    xA, yA, wA, hA = boxA
    xB, yB, wB, hB = boxB

    xi1 = max(xA, xB)
    yi1 = max(yA, yB)
    xi2 = min(xA + wA, xB + wB)
    yi2 = min(yA + hA, yB + hB)

    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    union_area = wA * hA + wB * hB - inter_area

    return inter_area / (union_area + 1e-6)


def compute_iou_matrix(detections: list, tracks: list) -> np.ndarray:
    """构建 IoU 代价矩阵:行=检测结果,列=跟踪轨迹。"""
    n, m = len(detections), len(tracks)
    matrix = np.zeros((n, m))
    for i, det in enumerate(detections):
        for j, trk in enumerate(tracks):
            matrix[i, j] = 1.0 - compute_iou(det, trk)  # 代价 = 1 - IoU
    return matrix


# ─── SORT 跟踪器 ─────────────────────────────────────────────────

class SORT:
    """
    SORT 多目标跟踪器。

    Parameters
    ----------
    max_age : int
        跟踪未更新时的最大帧数,超过后删除
    min_hits : int
        确认跟踪所需的最少检测次数
    iou_threshold : float
        匹配的最小 IoU(通常 0.3)
    """

    def __init__(self, max_age: int = 30, min_hits: int = 3, iou_threshold: float = 0.3):
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.tracks: list[Track] = []
        self.frame_count = 0

    def update(self, detections: np.ndarray, frame_id: int) -> np.ndarray:
        """
        用当前帧的检测结果更新跟踪器。

        Parameters
        ----------
        detections : (N, 4) 数组 — 每个检测 [x, y, w, h]

        Returns
        -------
        tracked_bboxes : (M, 5) 数组 — [x, y, w, h, track_id],仅确认的轨迹
        """
        self.frame_count += 1

        # ── 1. 预测步:推进所有轨迹 ──────────────────
        for track in self.tracks:
            track.predict()

        # ── 2. 匈牙利分配 ──────────────────────────────
        if len(detections) == 0:
            detected = []
        else:
            matched, unmatched_dets, unmatched_trks = self._associate(detections)

            # 更新已匹配的轨迹
            for d_idx, t_idx in matched:
                det = detections[d_idx]
                self.tracks[t_idx].update(det)

            # 为未匹配的检测创建新轨迹
            for d_idx in unmatched_dets:
                det = detections[d_idx]
                self.tracks.append(Track(det, frame_id))

        # ── 3. 删除过期轨迹 ─────────────────────────────────
        self.tracks = [
            t for t in self.tracks
            if t.time_since_update <= self.max_age and t.hits >= self.min_hits
        ]

        # ── 4. 输出已确认的轨迹 ───────────────────────────
        result = []
        for track in self.tracks:
            if track.state() == 'confirmed':
                x, y, w, h = [int(v) for v in track.bbox]
                result.append([x, y, w, h, track.id])
                track.trace.append((x + w//2, y + h//2))

        return np.array(result, dtype=np.int32) if result else np.empty((0, 5))

    def _associate(self, detections: np.ndarray) -> tuple:
        """基于 IoU 代价矩阵,用匈牙利算法匹配检测和轨迹。"""
        if len(self.tracks) == 0:
            return [], list(range(len(detections))), []

        cost_matrix = compute_iou_matrix(detections, [t.bbox for t in self.tracks])

        # 匈牙利算法(最小化代价)
        row_idx, col_idx = linear_sum_assignment(cost_matrix)

        matched = []
        unmatched_dets = list(range(len(detections)))
        unmatched_trks = list(range(len(self.tracks)))

        for r, c in zip(row_idx, col_idx):
            if cost_matrix[r, c] < (1.0 - self.iou_threshold):
                matched.append((r, c))
                if r in unmatched_dets:
                    unmatched_dets.remove(r)
                if c in unmatched_trks:
                    unmatched_trks.remove(c)

        return matched, unmatched_dets, unmatched_trks


# ─── SORT + YOLO 完整演示 ────────────────────────────────

def run_sort_with_yolo(video_source: int = 0, confidence: float = 0.3):
    """
    使用 YOLO 检测运行 SORT 多目标跟踪器。

    需要: pip install ultralytics
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        print("[ERROR] 未安装 ultralytics。请运行: pip install ultralytics")
        return

    # 加载 YOLO 模型(YOLOv8n = nano,最快)
    print("[INFO] 加载 YOLOv8n...")
    model = YOLO("yolov8n.pt")

    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        print("[ERROR] 无法打开摄像头")
        return

    tracker = SORT(max_age=30, min_hits=3, iou_threshold=0.3)
    frame_id = 0

    # COCO 类别名称,用于人/车跟踪
    CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck']

    print("[INFO] SORT+DeepSORT 跟踪已启动。按 'q' 退出。")

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_id += 1

        # ── 检测 ──
        results = model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            cls_id = int(box.cls[0])
            conf = float(box.conf[0])
            if conf < confidence:
                continue
            if results.names[cls_id] not in CLASSES:
                continue
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            x, y = int(x1), int(y1)
            w, h = int(x2 - x1), int(y2 - y1)
            detections.append(np.array([x, y, w, h], dtype=np.float32))

        detections = np.array(detections) if detections else np.empty((0, 4))

        # ── 跟踪 ──
        tracked = tracker.update(detections, frame_id)

        # ── 可视化 ──
        for x, y, w, h, track_id in tracked:
            cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
            cv2.putText(frame, f"ID:{track_id}", (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        cv2.putText(frame, f"Frame: {frame_id}  Tracks: {len(tracked)}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
        cv2.imshow("SORT + YOLO", frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    run_sort_with_yolo(video_source=0)

4.4 DeepSORT — 外观描述符扩展

DeepSORT 通过添加**深度外观描述符**(ReID 网络)扩展 SORT,以处理 IoU 匹配失效的遮挡情况:

"""
DeepSORT 扩展 — 外观特征提取
============================
使用预训练的 ReID 网络提取外观描述符,
通过余弦距离匹配结合 IoU 进行关联。
"""

import torch
import torch.nn as nn
import numpy as np


class EmbeddingNetwork(nn.Module):
    """
    简单的 CNN 嵌入网络,用于 ReID。
    输出 128 维单位归一化特征向量。
    生产环境建议使用在 Market-1501 上预训练的 OSNet 或 ResNet-50。
    """

    def __init__(self, input_dim: int = 512, embedding_dim: int = 128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(128, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn(x)
        # L2 归一化,用于余弦相似度
        x = nn.functional.normalize(x, p=2, dim=1)
        return x


def cosine_distance(feat1: np.ndarray, feat2: np.ndarray) -> float:
    """两个特征向量之间的余弦距离。"""
    return 1.0 - np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2) + 1e-6)


# DeepSORT 综合使用余弦距离和 IoU:
# cost = lambda * (1 - IoU) + (1 - lambda) * (1 - cosine_sim)
# 典型 lambda = 0.98(优先外观而非 IoU)

5. 第三层 — 现代方法:基于 Transformer 的跟踪器

5.1 核心思想

Transformer 跟踪器用**注意力机制**取代了手工设计的特征和相关性滤波器,能够学习全局上下文。主流范式是**模板匹配**:模板(初始目标)和**搜索区域**(当前帧)被标记化(tokenize)后由 Transformer 编码器-解码器处理。

┌─────────────────────────────────────────────────────────────────┐
│              Transformer 跟踪器架构                              │
│                                                                  │
│   第 t-1 帧(模板)          第 t 帧(搜索)                       │
│   ┌─────────────────┐       ┌──────────────────┐              │
│   │  BBox: (x,y,w,h)│       │   全帧图像        │              │
│   └────────┬────────┘       └────────┬─────────┘              │
│            │                        │                         │
│            ▼                        ▼                         │
│   ┌─────────────────┐       ┌──────────────────┐              │
│   │ 模板 Patch       │       │ 搜索区域          │              │
│   │ (裁剪 + 缩放)    │       │ (3 倍模板大小)    │              │
│   └────────┬────────┘       └────────┬─────────┘              │
│            │                        │                         │
│            ▼                        ▼                         │
│   ┌──────────────────────────────────────────────┐             │
│   │              Token 嵌入                        │             │
│   │   [CLS] t1 t2 ... tk [SEP] s1 s2 ... sm      │             │
│   └────────┬────────────────────────────────────┘             │
│            │                                                     │
│            ▼                                                     │
│   ┌──────────────────────────────────────────────┐             │
│   │          Transformer 编码器 (SA + CA)          │             │
│   │  自注意力: 模板-模板                           │             │
│   │  交叉注意力: 模板 ↔ 搜索                       │             │
│   └────────┬────────────────────────────────────┘             │
│            │                                                     │
│            ▼                                                     │
│   ┌──────────────────────────────────────────────┐             │
│   │          Transformer 解码器                    │             │
│   │  Query = 可学习的 [QUERY] token               │             │
│   │  输出 = 边界框回归                             │             │
│   └──────────────────────────────────────────────┘             │
│            │                                                     │
│            ▼                                                     │
│        BBox: (x, y, w, h) + 置信度                              │
└─────────────────────────────────────────────────────────────────┘

核心创新:

  1. 交叉注意力:模板 token 关注搜索 token,学习哪些搜索特征对应于模板。这使得跟踪器能够处理相关滤波器无法处理的**外观变化**(形变、遮挡)。

  2. 全局上下文:与 KCF 的局部相关不同,Transformer 捕获搜索区域内所有像素之间的**长距离依赖关系**。

  3. 模板更新:大多数现代跟踪器(OSTrack、MixFormer)使用最近帧的特征在线更新模板,从而实现长期跟踪。

5.2 注意力机制

缩放点积注意力

\[ \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\!\left( \frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} \right) \mathbf{V} \]

其中: - \(\mathbf{Q}\)(Query)—— 我们要查找的内容 - \(\mathbf{K}\)(Key)—— 图像中包含的内容 - \(\mathbf{V}\)(Value)—— 实际的特征值 - \(d_k\) — Key 维度(用于缩放)

多头注意力

\[ \text{MHA}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \mathbf{W}^O \]

其中每个 \(\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)\)

5.3 完整 Python 代码 — 基于注意力的跟踪器

"""
第三层:基于 Transformer 的视觉目标跟踪器
==========================================
简化版 Transformer 跟踪器,演示了:
1. Token 嵌入(patch + 位置)
2. 自注意力和交叉注意力
3. 从 [CLS] token 回归边界框

生产环境建议使用 OSTrack、MixFormer 或 TransT。
"""

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# ─── 注意力层 ──────────────────────────────────────────────

class ScaledDotProductAttention(nn.Module):
    """缩放点积注意力:Q, K, V → 注意力分数。"""

    def __init__(self, d_k: int):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                mask: torch.Tensor = None) -> tuple:
        """
        Parameters
        ----------
        Q : (B, num_heads, seq_len, d_k)
        K : (B, num_heads, seq_len, d_k)
        V : (B, num_heads, seq_len, d_v)
        mask : optional, (B, 1, seq_len, seq_len)

        Returns
        -------
        output : (B, num_heads, seq_len, d_v)
        attn_weights : (B, num_heads, seq_len, seq_len)
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights


class MultiHeadAttention(nn.Module):
    """多头注意力,带可学习的投影。"""

    def __init__(self, d_model: int = 256, num_heads: int = 8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.d_model = d_model

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                mask: torch.Tensor = None) -> tuple:
        B, seq_len, d_model = Q.shape

        # 投影并reshape为 (B, num_heads, seq_len, d_k)
        Q = self.W_Q(Q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # 广播到所有头

        out, attn = self.attention(Q, K, V, mask)

        # 拼接所有头并投影
        out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        out = self.W_O(out)
        return out, attn


class FeedForward(nn.Module):
    """逐位置前馈网络:linear → GELU → linear。"""

    def __init__(self, d_model: int = 256, d_ff: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))


class TransformerEncoderLayer(nn.Module):
    """单层 Transformer 编码器:自注意力 + FFN。"""

    def __init__(self, d_model: int = 256, num_heads: int = 8, d_ff: int = 1024):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # 带残差连接的自注意力
        attn_out, _ = self.mha(x, x, x, mask)
        x = self.norm1(x + attn_out)

        # 带残差连接的 FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


class CrossAttentionLayer(nn.Module):
    """交叉注意力:Query 来自一个模态,Key/Value 来自另一个。"""

    def __init__(self, d_model: int = 256, num_heads: int = 8, d_ff: int = 1024):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, query: torch.Tensor, key_value: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        query : (B, len_q, d_model)
        key_value : (B, len_kv, d_model)
        """
        attn_out, _ = self.mha(query, key_value, key_value, mask)
        x = self.norm1(query + attn_out)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


# ─── Patch 嵌入 ──────────────────────────────────────────────

class PatchEmbedding(nn.Module):
    """通过卷积将图像 patch 转换为 token 序列。"""

    def __init__(self, patch_size: int = 16, embed_dim: int = 256, in_channels: int = 3):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size,
                               stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : (B, C, H, W)
        Returns: (B, num_patches, embed_dim)
        """
        x = self.proj(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        return x


class PositionalEmbedding(nn.Module):
    """Patch 的可学习位置编码。"""

    def __init__(self, num_patches: int, embed_dim: int):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pos_embed[:, :x.size(1), :]


# ─── 简化的 Transformer 跟踪器 ─────────────────────────────

class SimpleTransformerTracker(nn.Module):
    """
    简化的 Transformer 跟踪器,支持模板 + 搜索 tokenization。
    """

    def __init__(self, embed_dim: int = 256, num_heads: int = 8,
                 num_layers: int = 3, patch_size: int = 16):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        # 模板和搜索的特征提取器
        self.patch_embed = PatchEmbedding(patch_size, embed_dim, in_channels=3)
        num_patches = (128 // patch_size) ** 2  # 128x128 搜索区域
        self.pos_embed = PositionalEmbedding(num_patches * 2 + 2, embed_dim)

        # 可学习的 [CLS]、[TPL] 和 [SRC] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.template_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.search_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer 层
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        self.cross_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        # BBox 回归头:从 [CLS] token 预测
        self.bbox_head = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 4),  # [dx, dy, dw, dh] 相对偏移量
            nn.Sigmoid()
        )

        # 初始化权重
        nn.init.xavier_uniform_(self.cls_token)
        nn.init.xavier_uniform_(self.template_token)
        nn.init.xavier_uniform_(self.search_token)

    def tokenize(self, template_img: torch.Tensor, search_img: torch.Tensor) -> torch.Tensor:
        """
        将模板和搜索 patch 组合成带有 token 的单个序列。

        Returns: (B, seq_len, embed_dim)
        """
        # 嵌入 patch
        template_patches = self.patch_embed(template_img)   # (B, T, D)
        search_patches = self.patch_embed(search_img)       # (B, S, D)

        # 添加 [TPL] 和 [SRC] token
        B = template_patches.size(0)
        tpl_tokens = self.template_token.expand(B, -1, -1)
        src_tokens = self.search_token.expand(B, -1, -1)

        # 拼接:[CLS], [TPL], template_patches, [SEP], search_patches
        tokens = torch.cat([
            self.cls_token.expand(B, -1, -1),
            tpl_tokens,
            template_patches,
            src_tokens,
            search_patches,
        ], dim=1)

        return self.pos_embed(tokens)

    def forward(self, template_img: torch.Tensor, search_img: torch.Tensor):
        """
        前向传播。

        Parameters
        ----------
        template_img : (B, 3, 64, 64) — 裁剪并调整大小后的模板
        search_img : (B, 3, 128, 128) — 搜索区域

        Returns
        -------
        bbox : (B, 4) — [cx_rel, cy_rel, w_rel, h_rel],范围 [0, 1]
        """
        tokens = self.tokenize(template_img, search_img)

        # 所有 token 之间的自注意力
        for layer in self.encoder_layers:
            tokens = layer(tokens)

        # 交叉注意力:[CLS] 关注搜索区域
        for layer in self.cross_layers:
            cls_out = layer(tokens[:, :1], tokens)

        # 使用 [CLS] token 进行边界框回归
        cls_features = tokens[:, 0]  # (B, D)
        bbox = self.bbox_head(cls_features)

        return bbox


# ─── 最小演示(模拟) ────────────────────────────────────

def run_transformer_tracker_demo(video_source: int = 0):
    """
    在实时摄像头上演示简化的 Transformer 跟踪器。
    使用上面的 PyTorch 模型进行模拟跟踪(无实际训练)。
    实际使用请从官方仓库加载 OSTrack 权重。
    """
    print("[INFO] Transformer 跟踪器演示")
    print("[INFO] 生产环境请使用 OSTrack: pip install osrtrack")
    print("[INFO] 此演示展示架构和流程。")

    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        print("[ERROR] 无法打开摄像头")
        return

    # 加载模型(简化版,未预训练)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleTransformerTracker(embed_dim=128, num_heads=4, num_layers=2)
    model = model.to(device)
    model.eval()

    # 获取第一帧并选择 bbox
    ret, frame = cap.read()
    if not ret:
        print("[ERROR] 无法读取第一帧")
        return

    print("[INFO] 绘制一个边界框,然后按 ENTER")
    bbox = cv2.selectROI("选择目标", frame, fromCenter=False, showCrosshair=True)
    cv2.destroyWindow("选择目标")

    x, y, w, h = bbox
    template_size = 64
    search_size = 128

    def crop_patch(frame, cx, cy, sz):
        """裁剪并调整正方形 patch 的大小。"""
        half = sz // 2
        x1, y1 = max(0, int(cx) - half), max(0, int(cy) - half)
        x2, y2 = min(frame.shape[1], int(cx) + half), min(frame.shape[0], int(cy) + half)
        patch = frame[y1:y2, x1:x2]
        if patch.size == 0:
            return np.zeros((sz, sz, 3), dtype=np.uint8)
        return cv2.resize(patch, (sz, sz))

    template = crop_patch(frame, x + w/2, y + h/2, template_size)
    search_region = crop_patch(frame, x + w/2, y + h/2, search_size)

    print("[INFO] 跟踪已启动。按 'q' 退出。")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        tpl_tensor = torch.from_numpy(template).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0
        src_tensor = torch.from_numpy(search_region).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0

        with torch.no_grad():
            output = model(tpl_tensor, src_tensor)

        dx, dy, dw, dh = output[0].cpu().numpy()
        cx, cy = x + w/2 + (dx - 0.5) * search_size, y + h/2 + (dy - 0.5) * search_size
        new_w, new_h = w * (1 + dw), h * (1 + dh)

        # 更新模板(在线更新)
        template = crop_patch(frame, cx, cy, template_size)
        search_region = crop_patch(frame, cx, cy, search_size)

        # 绘制结果
        ix, iy, iw, ih = int(cx - new_w/2), int(cy - new_h/2), int(new_w), int(new_h)
        cv2.rectangle(frame, (ix, iy), (ix+iw, iy+ih), (255, 0, 0), 2)
        cv2.putText(frame, "Transformer Tracker", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
        cv2.imshow("Transformer Tracker", frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    run_transformer_tracker_demo(video_source=0)

6. 评估指标

6.1 MOTA 与 MOTP

MOTA(多目标跟踪精度):

\[ \text{MOTA} = 1 - \frac{\sum_t (m_t + f_m + g_t)}{\sum_t g_t} \in (-\infty, 1] \]

其中: - \(m_t\) — 第 \(t\) 帧的漏检(假阴性) - \(f_m\) — 第 \(t\) 帧的误检(假阳性) - \(g_t\) — 第 \(t\) 帧的真实目标总数

MOTP(多目标跟踪精度):

\[ \text{MOTP} = \frac{\sum_t \sum_i d_t^i}{\sum_t c_t} \]

其中 \(d_t^i\) 是第 \(t\) 帧第 \(i\) 个匹配检测的边界框重叠(IoU),\(c_t\) 是匹配数量。

6.2 ID 切换与碎片化

指标 描述
ID Switches 跟踪分配 ID 发生改变的次数
Fragments 跟踪被中断后恢复的次数
MT (Mostly Tracked) 覆盖率 ≥ 80% 真实轨迹的比例
ML (Mostly Lost) 覆盖率 ≤ 20% 真实轨迹的比例

6.3 实现代码

"""
MOT 评估工具
============
计算跟踪结果与真值之间的 MOTA、MOTP 及相关指标。
"""

import numpy as np


def compute_iou(boxA: np.ndarray, boxB: np.ndarray) -> float:
    """两个 [x, y, w, h] 边界框之间的 IoU。"""
    xA, yA, wA, hA = boxA
    xB, yB, wB, hB = boxB
    xi1, yi1 = max(xA, xB), max(yA, yB)
    xi2, yi2 = min(xA + wA, xB + wB), min(yA + hA, yB + hB)
    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    union = wA * hA + wB * hB - inter
    return inter / (union + 1e-6)


def evaluate_tracking(tracked: list, ground_truth: list) -> dict:
    """
    计算 MOT 指标。

    Parameters
    ----------
    tracked : 每帧的字典列表:[{"id": int, "bbox": [x,y,w,h]}, ...]
    ground_truth : 同上格式

    Returns
    -------
    metrics : 包含 MOTA、MOTP、IDSw、IDs 等的字典
    """
    total_misses = 0
    total_false_positives = 0
    total_switches = 0
    total_gt = 0
    total_overlap = 0.0
    total_matched = 0

    for frame_tracked, frame_gt in zip(tracked, ground_truth):
        gt_ids = {d["id"] for d in frame_gt}
        trk_ids = {d["id"] for d in frame_tracked}

        # 按 ID 匹配(简化版 — 假设 ID 已对齐)
        for gt_d in frame_gt:
            best_iou, best_trk = 0.0, None
            for trk_d in frame_tracked:
                if trk_d["id"] != gt_d["id"]:
                    continue
                iou = compute_iou(gt_d["bbox"], trk_d["bbox"])
                if iou > best_iou:
                    best_iou, best_trk = iou, trk_d

            if best_trk is not None:
                total_overlap += best_iou
                total_matched += 1
            else:
                total_misses += 1

        for trk_d in frame_tracked:
            if trk_d["id"] not in {d["id"] for d in frame_gt}:
                total_false_positives += 1

        total_gt += len(frame_gt)

    mota = 1.0 - (total_misses + total_false_positives + total_switches) / max(total_gt, 1)
    motp = total_overlap / max(total_matched, 1)

    return {
        "MOTA": mota,
        "MOTP": motp,
        "ID Sw": total_switches,
        "Total GT": total_gt,
        "Missed": total_misses,
        "False Positives": total_false_positives,
    }

7. 分步实施指南

第一阶段 — 第一层:KCF 跟踪器(第 1 周)

  1. 从 OpenCV 内置 cv2.TrackerKCF_create() 开始——验证其能达到 200+ FPS
  2. 从头实现 HOG 特征提取器(3.3 节)
  3. 实现循环矩阵训练步骤(使用 FFT)
  4. 在视频序列上测试(常用 OTB-50 数据集)
  5. 与 OpenCV 内置跟踪器对比验证正确性

第二阶段 — 第二层:SORT 跟踪器(第 2 周)

  1. 实现卡尔曼滤波器(4.3 节)——用简单 1D 跟踪示例验证
  2. 通过 scipy.optimize.linear_sum_assignment 实现匈牙利算法
  3. 将 YOLO 检测与 ultralytics 集成
  4. 在视频上运行 SORT——调整 iou_threshold 并统计 ID 切换次数
  5. 扩展到 DeepSORT:添加外观特征提取(余弦距离匹配)
  6. 在 MOT17 数据集上用 MOTA/MOTP 评估

第三阶段 — 第三层:Transformer 跟踪器(第 3 周)

  1. 学习注意力机制:从头实现缩放点积注意力
  2. 从官方仓库运行 OSTrack 或 MixFormer(预训练权重可用)
  3. 分析交叉注意力图,了解模型关注的位置
  4. 实验不同的模板更新策略(固定模板 vs. 在线更新)
  5. 在 VOT2020/OTB100 上基准测试,并与第一层和第二层比较

8. 方法对比

指标 KCF 跟踪器 SORT DeepSORT Transformer 跟踪器
架构 相关滤波器 检测 + KF + 匈牙利 检测 + KF + 外观 + 匈牙利 注意力 + 模板匹配
单/多目标 单目标 多目标 多目标 单目标 / 多目标
速度(FPS) ~300 ~100 ~30-50 ~30
精度 ★★★☆☆ ★★★☆☆ ★★★★☆ ★★★★★
遮挡处理 一般(仅 IoU) 好(外观) 优秀
尺度变化 一般 优秀
需要 GPU 建议
需要训练 否(在线) 特征提取器 是(完整模型)
ID 持久性 中等 非常高
适用场景 简单场景,高 FPS 实时多目标 行人/车辆跟踪 复杂基准测试
复杂度 ⭐⭐ ⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐⭐

9. 扩展与变体

9.1 孪生网络

孪生跟踪器(SiamFC、SiamRPN、SiamMask)使用**双分支网络**: - 模板分支:编码初始目标 - 搜索分支:编码当前帧 - 相似度图:在特征空间通过互相关计算

SiamMask 将此扩展到**语义分割**,实现更高精度的目标定位。

9.2 全卷积跟踪器(FCNT)

FCNT 分析 CNN(VGG-16)的**激活图**,选择最具判别性的特征层进行跟踪,避免背景干扰。

9.3 长期跟踪

对于完全遮挡后需要**目标重检测**的长期跟踪:

  • ECO(高效卷积算子):使用分解卷积和训练样本选择,减少长期漂移
  • ATOM(基于重叠最大化的精确跟踪):学习直接估计边界框重叠
  • LaSOT 数据集:包含 1,400+ 视频的长期跟踪基准

9.4 实际应用

  • 自动驾驶:跨帧跟踪行人、车辆、骑行者
  • 安防监控:多摄像头行人重识别
  • 体育分析:足球/篮球中的球员和球跟踪
  • 医学成像:显微镜视频中的细胞和器官跟踪
  • 机器人学:视觉伺服中使用跟踪的参考点

10. 参考资料

  1. Henriques, J.F., et al. (2014). "High-Speed Tracking with Kernelized Correlation Filters." IEEE TPAMI, 37(3), 583–596. — KCF 算法论文
  2. Bolme, D.S., et al. (2010). "Visual Object Tracking using Adaptive Correlation Filters." CVPR. — MOSSE(KCF 前身)
  3. Bewley, A., et al. (2016). "Simple Online and Realtime Tracking." ICIP. — SORT 算法
  4. Wojke, N., et al. (2017). "Simple Online and Realtime Tracking with a Deep Association Metric." ICIP. — DeepSORT 扩展
  5. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS. — Transformer 架构
  6. Chen, X., et al. (2022). "Unifying Short-Term and Long-Term Visual Tracking via Refinement." — OSTrack
  7. Yan, B., et al. (2021). "TransT: Transformer-Based Tracking." CVPR. — TransT 架构
  8. Zhang, Z., et al. (2022). "MixFormer: End-to-End Tracking with Iterative Mixed Attention." — MixFormer
  9. Müller, M., et al. (2018). "Tracking Attackers in UAV Videos with Adaptive Siamese Networks." — 孪生跟踪
  10. Kristan, M., et al. (2020). "The Eighth Visual Object Tracking VOT2020 Challenge Results." — VOT 基准
  11. Liang, C., et al. (2015). "Particle Filter-based Visual Tracking: A Survey." — 综合综述
  12. OpenCV Tracking API — docs.opencv.org
  13. SORT Implementation — GitHub: abewley/sort
  14. DeepSORT Implementation — GitHub: nwojke/deep_sort
  15. OSTrack — GitHub: osiam/ostrack
  16. MOT17 / MOT20 Benchmarks — motchallenge.net