LSTM自动编码器进行时间序列异常检测(Pytorch)
1|0环境准备
本次数据集的格式.arff
,需要用到arff2pandas
模块读取。
导入相关模块
2|0核心内容
本案例使用真实的心电图 (ECG) 数据来检测患者心跳的异常情况。我们将一起构建一个 LSTM 自动编码器,使用来自单个心脏病患者的真实心电图数据对其进行训练,并将在新的样本中,使用训练好的模型对其进行预测分类为正常或异常来来检测异常心跳。
本案例主要围绕以下几大核心展开。
2|1数据集
该数据集包含 5,000 个通过 ECG 获得的时间序列样本,样本一共具有 140 个时间步长。每个序列对应于一个患有充血性心力衰竭的患者的一次心跳。
我们有 5 种类型的心跳类别,他们分别是:
-
正常 (N)
-
室性早搏 (R-on-T PVC)
-
室性早搏 (PVC)
-
室上性早搏或异位搏动(SP 或 EB)
-
未分类的搏动 (UB)。
如果你的设备安装有 GPU,这将是非常好的,因为他的运行速度更快,可以节约你宝贵的时间。
数据读取
把训练和测试数据组合成一个单一的数据框。两者的加成,将为我们提供更多数据来训练我们的自动编码器。
看下数据集样貌。
我们有5000个例子。每一行代表一个心跳记录。我们重新命名所有的类。并将最后一列重命名为target
,这样在后面引用它将更为方便。
通过函数value_counts()
可以看看每个不同的心跳类分别有多少个样本。
当然,为了更加直观,我们通过可视化方法将心跳类别通过sns.countplot()
清晰展示出。
通过统计分析,我们发现普通类的样本最多。这个结果是非常理想的,也是意料之中的(异常检测中的异常往往是最少的),又因为我们需要使用这些正常类的数据来训练模型。
接下来,我们看一下每个类的平均时间序列(前面和后面做一个标准差平滑)。
首先定义一个辅助绘图函数。
根据上面的定义的辅助函数,循环绘制每个心跳类的平滑曲线。
根据上面五种心跳类的可视化结果看出,正常类具有与所有其他类明显不同的特征,这也许就是我们构建的模型能够检测出异常的关键所在。
3|0LSTM自动编码器
自编码器模型架构图解
自动编码器模型是一种神经网络,旨在以无监督的方式学习恒等函数以重建原始输入,同时在此过程中压缩数据,从而发现更有效和压缩的表示。
该网络可以看作由两部分组成:一个编码器函数 和一个生成重构的解码器
-
编码器网络:将原始的高维输入转换为潜在的低维代码。输入尺寸大于输出尺寸。
-
解码器网络:解码器网络从代码中恢复数据,输出层可能越来越大。
编码器网络本质上完成了降维,就像我们如何使用主成分分析(PCA)或矩阵分解(MF)一样。此外,自动编码器针对代码中的数据重构进行了显式优化。一个好的中间表示不仅可以捕获潜在变量,而且有利于完整的解压过程。
该模型包含由 ϕ 参数化的编码器函数 和由 θ 参数化的解码器函数 。在瓶颈层为输入x学习的低维代码为 ϕ,重构输入为 θϕ。
参数 (θ,ϕ) 一起学习以输出与原始输入相同的重构数据样本,θϕ,或者换句话说,学习恒等函数。有多种指标可以量化两个向量之间的差异,例如激活函数为 sigmoid 时的交叉熵,或者像 MSE 损失一样简单:
3|1数据预处理
获取所有正常的心跳并删除目标类的列。
合并所有其他类并将它们标记为异常。
将正常类样本分为训练集、验证集和测试集。
需要将样本转换为张量,使用它们来训练自动编码器。为此编写一个辅助函数来实现样本数据类型的转换,以便后续复用。
每个时间序列将被转换为形状 序列长度 x *特征数量 *的二维张量 。在我们的例子中为140x1的二维张量。
接下来将所有需要用到的数据集进行如上转换。
3|2构建 LSTM 自动编码器
自动编码器的工作是获取一些输入数据,将其通过模型传递,并获得输入的重构,重构应该尽可能匹配输入。
从某种意义上说,自动编码器试图只学习数据中最重要的特征,这里使用几个 LSTM 层(即LSTM Autoencoder)来捕获数据的时间依赖性。接下来我们一起看看如何将时间序列数据提供给自动编码器。
为了将序列分类为正常或异常,需要设定一个阈值,并规定高于该阈值时,心跳是异常的。
3|3重构损失
当训练一个自动编码器时,模型目标是尽可能地重构输入。这里的目标是通过最小化损失函数来实现的(就像在监督学习中一样)。这里所使用的损失函数被称为重构损失
。常用的重构损失是交叉熵损失和均方误差。
接下来将以GitHub[3]中的 LSTM Autoencoder为基础,并进行一些小调整。因为模型的工作是重建时间序列数据,因此该模型需要从编码器开始定义。
编码器使用两个LSTM层压缩时间序列数据输入。
接下来,我们将使用Decoder对压缩表示进行解码。
编码器和解码器均包含两个 LSTM 层和一个提供最终重建的输出层。
这里将所有内容包装成一个易于使用的模块了。
自动编码器类已经定义好,接下来创建一个它的实例。
自动编码器模型已经定义好。接下来需要训练模型。下面为训练过程编写一个辅助函数train_model
。
在每个epoch中,训练过程为模型提供所有训练样本,并评估验证集上的模型效果。注意,这里使用的批处理大小为1 ,即模型一次只能得到一个序列。另外还记录了过程中的训练和验证集损失。
值得注意的是,重构时做的是最小化L1损失,它测量的是 MAE(平均绝对误差),似乎比 MSE(均方误差)更好。
最后,我们将获得具有最小验证误差的模型,并使用该模型进行接下来的异常检测预。现在开始做一些训练。
3|4绘制模型损失
绘制模型在训练和测试数据集上面的损失曲线。
从可视化结果看出,我们所训练的模型收敛得很好。看起来我们可能需要一个更大的验证集来优化模型,但本文就不做展开了,现在就这样了。
保存模型:存储模型以备后用。模型保存是必须要做的,他是保存和避免我们宝贵工作不被浪费的重要步骤。
如果要下载和加载预训练模型,请取消注释下一行。
3|5设定阈值
有了训练好了的模型,可以看看训练集上的重构误差。同样编写一个辅助函数来使用模型预测结果。
该预测函数遍历数据集中的每个样本并记录预测结果和损失。
从图结果看,该阈值设定为26较为合适。
4|0模型评估
利用上面设定的阈值,我们可以将问题转化为一个简单的二分类任务:
-
如果一个例子的重构损失低于阈值,我们将其归类为"正常"心跳
-
或者,如果损失高于阈值,我们会将其归类为**"异常"**
4|1正常心跳
我们检查一下模型在正常心跳上的表现如何。这里使用新的测试集中的正常心跳。
计算下模型预测正确的样本有多少。
4|2异常心跳
我们对异常样本执行相同的操作,由于异常心跳和正常心跳的样本数量不一致,因此需要获得一个与正常心跳大小相同的子集,并对异常子集进行模型的预测。
最后计算高于阈值的样本数量,而这些样本将被视为异常心跳数据。
由此可见,我们得到了很好的结果。在现实项目中,可以根据要容忍的错误类型来调整阈值。在这种情况下,可能希望误报(正常心跳被视为异常)多于漏报(异常被视为正常)。
4|3样本对比观察
可以叠加真实的和重构的时间序列值,看看它们有多接近。得到相比的结果,可以针对一些正常和异常情况进行处理。
到目前为止,该实战案例已经告一段落了。在本案例中,我们一起学习了如何使用 PyTorch 创建 LSTM 自动编码器并使用它来检测 ECG 数据中的心跳异常。
4|4torch.stack() 详解
沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同形状。
简而言之:把多个二维的张量凑成一个三维的张量;多个三维的凑成一个四维的张量…以此类推,也就是在增加新的维度进行堆叠。
参数:
-
inputs
(sequence of Tensors) - 待连接的张量序列。
注:python的序列数据只有list和tuple。函数中的输入inputs
只允许是序列;且序列内部的张量元素,必须shape
相等。 -
dim
(int) 新的维度, 必须在0到len(outputs)之间。注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。dim
是选择生成的维度,必须满足0<=dim<len(outputs)
;len(outputs)
是输出后的tensor
的维度大小。
例子
参考资料
[1]
来源: https://www.heartandstroke.ca/heart/tests/electrocardiogram
[2]
来源: https://en.wikipedia.org/wiki/Cardiac_cycle
[3]
GitHub: https://github.com/shobrook/sequitur
[4]
参考原文: https://curiousily.com/posts/time-series-anomaly-detection-using-lstm-autoencoder-with-pytorch-in-python/
[5]
Sequitur - Recurrent Autoencoder (RAE): https://github.com/shobrook/sequitur
[6]
Towards Never-Ending Learning from Time Series Streams: https://www.cs.ucr.edu/~eamonn/neverending.pdf
[7]
LSTM Autoencoder for Anomaly Detection: https://towardsdatascience.com/lstm-autoencoder-for-anomaly-detection-e1f4f2ee7ccf
__EOF__
作 者:清风紫雪
出 处:https://www.cnblogs.com/xiaofengzai/p/16243105.html
关于博主:编程路上的小学生,热爱技术,喜欢专研。评论和私信会在第一时间回复。或者直接私信我。
版权声明:署名 - 非商业性使用 - 禁止演绎,协议普通文本 | 协议法律文本。
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
2020-05-07 冲刺十四天(实现好友页的排序展示)