基于keras的时域卷积网络(TCN)

1 前言

时域卷积网络(Temporal Convolutional Network,TCN)属于卷积神经网络(CNN)家族,于2017年被提出,目前已在多项时间序列数据任务中击败循环神经网络(RNN)家族。

img TCN 网络结构

图中,xi 表示第 i 个时刻的特征,可以是多维的。

TCN源码见-->GitHub - philipperemy/keras-tcn: Keras Temporal Convolutional Network.,由于源码过于复杂,新手不易上手,笔者参照源码,手撕了个简洁版的TCN,与君共享。

本文以 MNIST 手写数字分类为例,讲解 TCN 模型。关于 MNIST 数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

img

代码资源见-->时域卷积网络(TCN)案例模型

2 实验

TCN.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Conv1D,Activation,Flatten,Dense

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
    valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
    test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#残差块
def ResBlock(x,filters,kernel_size,dilation_rate):
    r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate,activation='relu')(x) #第一卷积
    r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate)(r) #第二卷积
    if x.shape[-1]==filters:
        shortcut=x
    else:
        shortcut=Conv1D(filters,kernel_size,padding='same')(x)  #shortcut(捷径)
    o=add([r,shortcut])
    o=Activation('relu')(o)  #激活函数
    return o

#序列模型
def TCN(train_x,train_y,valid_x,valid_y,test_x,test_y):
    inputs=Input(shape=(28,28))
    x=ResBlock(inputs,filters=32,kernel_size=3,dilation_rate=1)
    x=ResBlock(x,filters=32,kernel_size=3,dilation_rate=2)
    x=ResBlock(x,filters=16,kernel_size=3,dilation_rate=4)
    x=Flatten()(x)
    x=Dense(10,activation='softmax')(x)
    model=Model(input=inputs,output=x)
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=30,verbose=2,validation_data=(valid_x,valid_y))
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
     
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
TCN(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 28, 28)       0                                            
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 28, 32)       2720        input_1[0][0]                    
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 28, 32)       3104        conv1d_1[0][0]                   
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 28, 32)       2720        input_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 28, 32)       0           conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 28, 32)       0           add_1[0][0]                      
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 28, 32)       3104        activation_1[0][0]               
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, 28, 32)       3104        conv1d_4[0][0]                   
__________________________________________________________________________________________________
add_2 (Add)                     (None, 28, 32)       0           conv1d_5[0][0]                   
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 28, 32)       0           add_2[0][0]                      
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, 28, 16)       1552        activation_2[0][0]               
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 28, 16)       784         conv1d_6[0][0]                   
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, 28, 16)       1552        activation_2[0][0]               
__________________________________________________________________________________________________
add_3 (Add)                     (None, 28, 16)       0           conv1d_7[0][0]                   
                                                                 conv1d_8[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 28, 16)       0           add_3[0][0]                      
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 448)          0           activation_3[0][0]               
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10)           4490        flatten_1[0][0]                  
==================================================================================================
Total params: 23,130
Trainable params: 23,130
Non-trainable params: 0

网络训练结果:

Epoch 28/30
 - 6s - loss: 0.0112 - acc: 0.9966 - val_loss: 0.0539 - val_acc: 0.9854
Epoch 29/30
 - 6s - loss: 0.0080 - acc: 0.9977 - val_loss: 0.0536 - val_acc: 0.9872
Epoch 30/30
 - 6s - loss: 0.0099 - acc: 0.9965 - val_loss: 0.0486 - val_acc: 0.9892
test_loss: 0.055041389787220396 - test_acc: 0.9855000048875808

可以看到,TCN模型的预测精度为 0.9855, 超越了 seq2seq模型案例分析 中 AttSeq2Seq 模型(0.9825)、基于keras的双层LSTM网络和双向LSTM网络 中 DoubleLSTM 模型(0.9789)和 BiLSTM 模型(0.9795)、基于keras的残差网络 中 ResNet 模型(0.9721)。

3 拓展延申

有时候,并不需要最后一层 TCN 输出序列的所有步,而只需要最后一层 TCN 输出序列的第一步或最后一步。这时候,需要借助 lambda 关键字定义 Lambda 层,取代 Flatten 层。如下:

from keras.layers import Lambda
......
x=ResBlock(x,filters=16,kernel_size=3,dilation_rate=4)
x=Lambda(lambda x: x[:,0,:])(x)  #此前是:x=Flatten()(x)
x=Dense(10,activation='softmax')(x)
......

lambda 关键字用于定义匿名函数,应用如下:

import numpy as np
f=lambda x: x*x+x+1
x=np.array([1,2,3])
y=f(x)
print(y)  #输出:[ 3  7 13]

​ 声明:本文转自基于keras的时域卷积网络(TCN)

posted @ 2023-03-19 12:06  little_fat_sheep  阅读(346)  评论(0编辑  收藏  举报