利用神经网络对脑电图(EEG)降噪------开源的、低成本、低功耗微处理器神经网络模型解决方案
具体的软硬件实现点击 http://mcu-ai.com/ MCU-AI技术网页_MCU-AI人工智能
这个示例展示了如何使用EEGdenoiseNet基准数据集[1]和深度学习回归去除脑电图(EEG)信号中的眼电图(EOG)噪声。EEGdenoiseNet数据集包含4514个干净的EEG片段和3400个眼部伪迹片段,这些片段可以用来合成带有真实干净EEG的噪声EEG片段
这个示例使用干净和受EOG污染的EEG信号来训练一个长短期记忆(LSTM)模型以去除EOG伪迹。首先,将在原始输入信号上训练模型。然后,引入短时傅里叶变换(STFT)层,使模型在原始输入上提取时频特征进行训练。逆STFT层从去噪的STFT重构结果。使用时频特征特别是在信噪比(SNR)较低时可以提高性能。
EEGdeniseNet数据集包含4514个干净的EEG片段和3400个EOG片段,可用于生成三个数据集,用于训练、验证和测试深度学习模型。所有信号段的采样率为256Hz。实例为MATLAB语言。
% Download the data
datasetZipFile = matlab.internal.examples.downloadSupportFile("SPT","data/EEGEOGDenoisingData.zip");
datasetFolder = fullfile(fileparts(datasetZipFile),"EEG_EOG_Denoising_Dataset");
if ~exist(datasetFolder,"dir")
unzip(datasetZipFile,fileparts(datasetZipFile));
end
下载数据后,datasetFolder中的位置包含两个MAT文件:
EEG_all_epochs.mat 干净EEG数据
EOG_all_epochs.mat (EOG)数据
将干净的EEG和EOG信号相结合,生成具有不同信噪比(SNR)的有噪声EEG数据与干净的EEG数据构成训练数据,并且分成训练、验证和测试数据集。
绘制有噪声EEG数据与干净的EEG数据
显然,传统的任何算法很难将EEG数据从噪声中滤出来。
定义神经网络结构,之所以选择长短期记忆(LSTM)架构,是因为它能够从时间序列中学习特征。
numFeatures = 1;
numHiddenUnits = 100;
layers = [
sequenceInputLayer(numFeatures)
lstmLayer(numHiddenUnits)
dropoutLayer(0.2)
fullyConnectedLayer(numFeatures)
];
设置训练参数
maxEpochs = 5;
miniBatchSize = 150;
options = trainingOptions("adam", ...
Metrics="rmse", ...
MaxEpochs=maxEpochs, ...
MiniBatchSize=miniBatchSize, ...
InitialLearnRate=0.005, ...
GradientThreshold=1, ...
Plots="training-progress", ...
Shuffle="every-epoch", ...
Verbose=false, ...
ValidationData=ds_Validate_T, ...
ValidationFrequency=100, ...
OutputNetwork="best-validation-loss");
模型执行的效果
提高深度学习模型性能的常用方法是使用输入信号数据的特征进行训练。这些特征提供了输入数据的表示,这使得网络更容易学习信号的最重要方面。
选择窗口长度为64个样本、重叠长度为63个样本的短时傅立叶变换(STFT)。这种转换将有效地创建33个复杂特征,每个特征的长度为449个样本。
winLength = 64;
overlapLength = 63;
data = preview(ds_Train_T);
plotSTFT(data,winLength,overlapLength)
定义神经网络
minLen=512; % signal length
numFeatures=66; % number of features
win=rectwin(winLength); % analysis window
layers = [
sequenceInputLayer(1,MinLength=minLen)
stftLayer(Window=win,OverlapLength=overlapLength,transform="realimag")
lstmLayer(numHiddenUnits)
dropoutLayer(0.2)
fullyConnectedLayer(numFeatures)
istftLayer(Window=win,OverlapLength=overlapLength)
];
训练网络
if trainingFlag == "Train networks"
stftNet = trainnet(ds_Train_T,layers,"mse",options);end
网络性能