卡尔曼滤波源码注释和调用示例

卡尔曼滤波源码注释和调用示例

flyfish

Python版本代码地址
C++版代码地址

主要用于分析代码,增加了中文注释

import numpy as np
import scipy.linalg

"""
0.95分位数的卡方分布表,N自由度(包含N=1到9的值)。
取自MATLAB/Octave的chi2inv函数,用作Mahalanobis门限。
"""
chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919
}

class KalmanFilter(object):
    """
    一个用于图像空间中跟踪边界框的简单卡尔曼滤波器。

    8维状态空间
        x, y, a, h, vx, vy, va, vh
    包含边界框中心位置 (x, y)、长宽比 a、高度 h 及其相应的速度。

    对象运动遵循恒定速度模型。边界框位置 (x, y, a, h) 被作为状态空间的直接观测值(线性观测模型)。
    """

    def __init__(self):
        ndim, dt = 4, 1.

        # 创建卡尔曼滤波器模型矩阵
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # 运动和观测不确定性相对于当前状态估计进行选择。这些权重控制模型中的不确定性量。这有点hacky。
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """
        从未关联的测量创建跟踪。

        参数
        ----------
        measurement : ndarray
            边界框坐标 (x, y, a, h) 包含中心位置 (x, y)、长宽比 a 和高度 h。

        返回值
        -------
        (ndarray, ndarray)
            返回新跟踪的均值向量(8维)和协方差矩阵(8x8维)。
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3]
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """
        基于模型预测下一状态。

        参数
        ----------
        mean : ndarray
            当前状态的均值向量(8维)。
        covariance : ndarray
            当前状态的协方差矩阵(8x8维)。

        返回值
        -------
        (ndarray, ndarray)
            返回预测的均值向量和协方差矩阵。
        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """
        将状态分布(均值和协方差)投影到观测空间。

        参数
        ----------
        mean : ndarray
            状态分布的均值向量(8维)。
        covariance : ndarray
            状态分布的协方差矩阵(8x8维)。

        返回值
        -------
        (ndarray, ndarray)
            返回观测空间中的均值向量(4维)和协方差矩阵(4x4维)。
        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]
        ]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement):
        """
        使用观测值更新状态分布。

        参数
        ----------
        mean : ndarray
            先验状态分布的均值向量(8维)。
        covariance : ndarray
            先验状态分布的协方差矩阵(8x8维)。
        measurement : ndarray
            当前观测到的边界框坐标 (x, y, a, h)。

        返回值
        -------
        (ndarray, ndarray)
            更新后的状态分布的均值向量和协方差矩阵。
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
                                             np.dot(covariance, self._update_mat.T).T, check_finite=False).T

        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False):
        """
        计算状态分布和观测值之间的门限距离。

        可从 `chi2inv95` 中获得合适的距离门限。如果 `only_position` 为 False,则卡方分布有4个自由度,否则为2个。

        参数
        ----------
        mean : ndarray
            状态分布的均值向量(8维)。
        covariance : ndarray
            状态分布的协方差矩阵(8x8维)。
        measurements : ndarray
            N×4维矩阵,包含N个观测值,每个观测值的格式为 (x, y, a, h),其中 (x, y) 为边界框中心位置,a 为长宽比,h 为高度。
        only_position : 可选[bool]
            如果为True,距离计算仅针对边界框中心位置。

        返回值
        -------
        ndarray
            返回长度为N的数组,其中第i个元素包含 (mean, covariance) 和 `measurements[i]` 之间的平方Mahalanobis距离。
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        cholesky_factor = np.linalg.cholesky(covariance)
        d = measurements - mean
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
        squared_maha = np.sum(z * z, axis=0)
        return squared_maha

调用示例1

import numpy as np
from kalman_filter_cn import KalmanFilter

class KalmanFilterTracker:
    def __init__(self, initial_measurement):
        self.kf = KalmanFilter()
        self.mean, self.covariance = self.kf.initiate(initial_measurement)
        self.history = [initial_measurement[:2]]  # 只记录位置 (x, y)

    def predict_and_update(self, measurement):
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.mean, self.covariance = self.kf.update(self.mean, self.covariance, measurement)
        self.history.append(self.mean[:2])  # 只记录位置 (x, y)
        return self.mean, self.covariance

# 示例用法
initial_measurement = np.array([0, 0, 1, 1])
tracker = KalmanFilterTracker(initial_measurement)

measurements = [
    np.array([1, 1, 1, 1]),
    np.array([2, 2, 1, 1]),
    np.array([3, 3, 1, 1]),
    np.array([4, 4, 1, 1]),
    np.array([5, 5, 1, 1])
]

for measurement in measurements:
    tracker.predict_and_update(measurement)

print("History of positions:", tracker.history)
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def animate_kalman_filter(history):
    fig, ax = plt.subplots()
    ax.set_xlim(0, 6)
    ax.set_ylim(0, 6)
    line, = ax.plot([], [], 'bo-', label='Kalman Filter')
    true_line, = ax.plot([], [], 'ro--', label='True Path')

    def init():
        line.set_data([], [])
        true_line.set_data([], [])
        return line, true_line

    def update(frame):
        x_data = [h[0] for h in history[:frame+1]]
        y_data = [h[1] for h in history[:frame+1]]
        line.set_data(x_data, y_data)

        true_x = [i for i in range(len(history))]
        true_y = [i for i in range(len(history))]
        true_line.set_data(true_x, true_y)
        return line, true_line

    ani = FuncAnimation(fig, update, frames=len(history), init_func=init, blit=True, repeat=True)
    ani.save('kalman_filter.gif', writer='imagemagick')
    plt.legend()
    plt.show()

animate_kalman_filter(tracker.history)

在这里插入图片描述

调用示例2

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from kalman_filter_cn import KalmanFilter
from typing import Tuple

class KalmanFilterTracker:
    def __init__(self, initial_measurement: np.ndarray) -> None:
        self.kf = KalmanFilter()
        self.mean, self.covariance = self.kf.initiate(initial_measurement)
        self.history = [initial_measurement[:2]]  # 只记录位置 (x, y)

    def predict_and_update(self, measurement: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.mean, self.covariance = self.kf.update(self.mean, self.covariance, measurement)
        self.history.append(self.mean[:2])  # 只记录位置 (x, y)
        return self.mean, self.covariance

class KalmanFilterAnimation:
    def __init__(self, tracker: KalmanFilterTracker, measurements: np.ndarray) -> None:
        self.tracker = tracker
        self.measurements = measurements

    def init(self):
        self.line.set_data([], [])
        self.true_line.set_data([], [])
        return self.line, self.true_line

    def update(self, frame):
        x_data = [h[0] for h in self.tracker.history[:frame+1]]
        y_data = [h[1] for h in self.tracker.history[:frame+1]]
        self.line.set_data(x_data, y_data)

        true_x = [m[0] for m in self.measurements[:frame+1]]
        true_y = [m[1] for m in self.measurements[:frame+1]]
        self.true_line.set_data(true_x, true_y)
        return self.line, self.true_line

    def animate(self) -> None:
        fig, ax = plt.subplots()
        ax.set_xlim(0, 10)
        ax.set_ylim(-1.5, 1.5)
        self.line, = ax.plot([], [], 'bo-', label='Kalman Filter')
        self.true_line, = ax.plot([], [], 'ro--', label='True Path')

        ani = FuncAnimation(fig, self.update, frames=len(self.tracker.history),
                            init_func=self.init, blit=True, repeat=True)
        ani.save('kalman_filter_curve.gif', writer='imagemagick')
        plt.legend()
        plt.show()

# 初始化卡尔曼滤波器
initial_measurement = np.array([0, 0, 1, 1])
tracker = KalmanFilterTracker(initial_measurement)

# 生成测量值,形成曲线轨迹(正弦波)
measurements = []
for t in np.linspace(0, 10, 100):
    x = t
    y = np.sin(t)
    measurements.append(np.array([x, y, 1, 1]))

# 更新卡尔曼滤波器
for measurement in measurements:
    tracker.predict_and_update(measurement)

# 创建动画并生成GIF
animation = KalmanFilterAnimation(tracker, measurements)
animation.animate()

请添加图片描述
请添加图片描述

如果要分析滤波器性能、调试滤波器以及可视化滤波器是非常有用的,那么可以这样做

class KalmanFilterTracker:
    def __init__(self, initial_measurement: np.ndarray) -> None:
        self.kf = KalmanFilter()
        self.mean, self.covariance = self.kf.initiate(initial_measurement)
        self.history = [initial_measurement[:2]]  # 只记录位置 (x, y)
        self.states = [self.mean]  # 存储历史状态均值
        self.covariances = [self.covariance]  # 存储历史协方差矩阵

    def predict_and_update(self, measurement: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.mean, self.covariance = self.kf.update(self.mean, self.covariance, measurement)
        self.history.append(self.mean[:2])  # 只记录位置 (x, y)
        self.states.append(self.mean)  # 存储历史状态均值
        self.covariances.append(self.covariance)  # 存储历史协方差矩阵
        return self.mean, self.covariance

记录历史值可以分析滤波器的性能,查找和修正可能的问题。对于可视化和演示目的,存储历史值可以让绘制出估计轨迹和实际轨迹,以便直观地比较和展示滤波效果。

如果只是单纯的用,在递归估计中,只需保持前一时刻的状态即可

class KalmanFilterTracker:
    def __init__(self, initial_measurement: np.ndarray) -> None:
        self.kf = KalmanFilter()
        self.mean, self.covariance = self.kf.initiate(initial_measurement)

    def predict_and_update(self, measurement: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        self.mean, self.covariance = self.kf.predict(self.mean, self.covariance)
        self.mean, self.covariance = self.kf.update(self.mean, self.covariance, measurement)
        return self.mean, self.covariance

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/713939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

多源最短路径算法 -- 弗洛伊德(Floyd)算法

1. 简介 Floyd算法,全名为Floyd-Warshall算法,亦称弗洛伊德算法或佛洛依德算法,是一种用于寻找给定加权图中所有顶点对之间的最短路径的算法。这种算法以1978年图灵奖获得者、斯坦福大学计算机科学系教授罗伯特弗洛伊德的名字命名。 2. 核心思…

打造私密的通信工具,极空间搭建免费开源的电子邮件管理程序『Cypht』

打造私密的通信工具,极空间搭建免费开源的电子邮件管理程序『Cypht』 哈喽小伙伴门好,我是Stark-C~ 说起电子邮件大家都不陌生,哪怕是在当前微信或者QQ已经非常普遍的今天,电子邮件在我们很多人的工作中都充当了重要的通信工具。…

【星座运势】本周财运分析,巨蟹座财富潜力大开!

大家好!今天我们来谈谈巨蟹座本周的财富运势。经过调查和数据分析,我发现巨蟹座这周的财运潜力很大!接下来,我将用通俗易懂的语言,通过代码说明,向大家展示巨蟹座的财富运势。 首先,我们需要通…

多设备互通、开箱即用的私有化笔记软件,极空间部署最强备忘录项目『Memos』

多设备互通、开箱即用的私有化笔记软件,极空间部署最强备忘录项目『Memos』 哈喽小伙伴们好,我是Stark-C~ 手机上的备忘录我想绝大多数的小伙伴都会用到,日常用来记录一下生活中的消费开支清单,或者工作中记录一些重要的任务或项…

【动态规划】0-1背包问题

【动态规划】0-1背包问题 题目:现在有四个物品,背包总容量为8,背包最多能装入价值为多少的物品? 我的图解 表格a【i】【j】表示的是容量为j的背包装入前i个物品的最大价值。 拿a【1】【1】来说,它的值就是背包容量为1,只考虑…

4.1 初探Spring Boot

初探Spring Boot实战概述 Spring Boot简介 Spring Boot是一个开源的Java框架,由Pivotal团队(现为VMware的一部分)开发,旨在简化Spring应用程序的创建和部署过程。它通过提供一系列自动化配置、独立运行的特性和微服务支持&#…

低代码开发MES系统,一周实现数字化

随着工业4.0和智能制造的兴起,企业对于生产过程的数字化、智能化需求日益迫切。制造执行系统(MES)作为连接计划层与控制层的关键信息系统,在提升生产效率、优化资源配置、保障产品质量等方面发挥着重要作用。然而,传统…

数据质量管理解决方案(55页PPT)

方案介绍: 数据质量管理解决方案是一个系统性的方法,旨在确保数据的准确性、完整性、一致性、可靠性和可用性。该解决方案覆盖了数据从产生到消亡的整个生命周期,包括数据的计划、获取、存储、共享、维护、应用和消亡等各个阶段。数据质量管…

IDEA导入项目报错java程序包不存在

如图文件结构,本来是在web-demo中操作,但是想导入一下其他模块,切换了项目文件的目录,发现需要重新对Tomcat等进行配置,配置好之后发现运行出现Java相关错误(如下)记录一下修正过程。 java: 程序…

【教程】Linux设置进程的优先级

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 关键指令 sudo chrt -f <优先级> <指令> 示例脚本 当然也可以不是启动Python脚本&#xff0c;普通的指令都可以&#xff0c;可自行适当修…

2024/6/16 英语每日一段

Nature has the means--to a degree--to limit the effects of climate change. Intact ecosystems such as forests, grasslands, oceans and peatlands are “carbon sinks”--natural storage systems that remove atmospheric carbon and other greenhouse gases--and are …

Intel HDSLB 高性能四层负载均衡器 — 代码剖析和高级特性

目录 文章目录 目录前言代码剖析软件架构目录结构配置解析启动流程分析数据面 jobs 注册数据面 jobs 执行 转发流程分析收包阶段L2 处理阶段L3 处理阶段L4 处理阶段 高级特性大象流转发优化快慢路径分离转发优化报文基础转发优化 最后参考文档 前言 在前 2 篇文章中&#xff0…

【云原生】Kubernetes----Kubernetes集群部署Prometheus 和Grafana

目录 引言 一、环境准备 二、部署node-exporter &#xff08;一&#xff09;创建命名空间 &#xff08;二&#xff09;部署node-exporter 1.获取镜像 2.定义yaml文件 3.创建服务 4.查看监控数据 三、部署Prometheus &#xff08;一&#xff09;创建账号并授权 &…

Java学习笔记之基本数据类型转换

前言 本篇文章是基于我本人在初学JAVA阶段想记录的的学习笔记&#xff0c;如有错误&#xff0c;恳请指正。今天要干掉的是JAVA的基本数据类型转换 Java的基本数据类型转换 前言一&#xff0c;基本数据类型复习二&#xff0c;基本介绍什么是自动类型转换&#xff1f; 三&#…

【Numpy】一文向您详细介绍 np.round()

【Numpy】一文向您详细介绍 np.round() 下滑即可查看博客内容 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地&#xff01;&#x1f387; &#x1f393; 博主简介&#xff1a;985高校的普通本硕&#xff0c;…

从0到1搭建MCU芯片上操作系统环境。开发都需要哪些环节和准备

MCU芯片环境搭建与操作系统上载步骤 1. 硬件准备 选择合适的MCU芯片&#xff0c;例如STM32、GD32等。 准备开发板&#xff0c;用于硬件连接和实验。 准备必要的外围设备&#xff0c;如电源适配器、USB转串口模块等。 2. 软件环境搭建 安装编程语言环境&#xff0c;如C/C编译…

NVIDIA Triton系列02-功能与架构简介

NVIDIA Triton系列02-功能与架构简介 B站&#xff1a;肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com) 博客&#xff1a;肆十二-CSDN博客 问答&#xff1a;(10 封私信 / 72 条消息) 肆十二 - 知乎 (zhihu.com) 前面文章介绍微软 Teams 会议系统、微信软件与腾讯…

微信视频号视频怎么下载才能保存视频到手机相册,推荐一款稳定的视频号下载工具

视频号视频下载发现写了很多次&#xff0c;竟然还有很多人不知道微信视频号视频怎么下载&#xff0c;今天就来说说这款视频号下载工具。 视频号下载工具介绍 这款视频号下载工具叫视频号下载plus&#xff0c;也有很多人称之为视频下载小助手不知道的可以自行百度。 注意在百度…

码住!详解时序数据库不同分类与性能对比

加速发展中的时序数据库&#xff0c;基于不同架构&#xff0c;最流行的类别是&#xff1f; 作为管理工业场景时序数据的新兴数据库品类&#xff0c;时序数据库凭借着对海量时序数据的高效存储、高可扩展性、时序分析计算等特性&#xff0c;一跃成为物联网时代工业领域颇受欢迎的…

SolarLab - hackthebox

简介 靶机名称&#xff1a;SolarLab 难度&#xff1a;中等 靶场地址&#xff1a;https://app.hackthebox.com/machines/SolarLab 本地环境 靶机IP &#xff1a;10.10.11.16 ubuntu渗透机IP(ubuntu 22.04)&#xff1a;10.10.16.17 windows渗透机IP&#xff08;windows11&…