第二十节,使用RNN网络拟合回声信号序列
这一节使用TensorFlow中的函数搭建一个简单的RNN网络,使用一串随机的模拟数据作为原始信号,让RNN网络来拟合其对应的回声信号。
样本数据为一串随机的由0,1组成的数字,将其当成发射出去的一串信号。当碰到阻挡被反弹回来时,会收到原始信号的回声。
如果步长为3,那么输入和输出的序列如下图所示:
原序列 | 0 | 1 | 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 1 | 1 | 0 | 1 | 1 |
回声序列 | null | null | null | 0 | 1 | 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 1 | 1 |
如上表所示,回声序列的前三项是null,原序列的第一个信号为0,对应的是回声序列的第四项,即回声序列的每一个数都比原序列滞后3个时序。本例的任务就是把序列截取出来,对于每个原序列来预测它的回声序列。。
构建的网络结构如下图所示:
上图中,初始的输入有5个,xt个为t时刻输入序列值,另外4个为t-1时刻隐藏层的输出值ht-1。通过一层具有4个节点的RNN网络,再接一个全连接输出两个类别,分别表示输出0,和1类别的概率。这样每个序列都会有一个对应的预测分类值,最终将整个序列生成了预测序列。
下面我们会演示一个例子,这里随机生成一个具有50000个序列样本数据,然后根据原序列生成50000个回声序列样本数据。我们每个训练截取15个序列作为一个样本,我们设置小批量大小batch_size为5。
- 我们把50000个序列,转换为5x10000的数组。
- 对数组的每一行按长度为15进行分割,每一个小批量含有5x15个序列。
- 针对每一小批量的序列,我们使用RNN网络开始迭代,迭代每一个批次中的每一组序列(5x1)。
注意这里面的5就是我们设置的batch_size大小,这和我们之前在CNN以及DNN网络中的batch_size是一样的,即一次训练使用batch_size个样本。
下面是一个小批量的原序列数据和回声序列数据,这里回声序列的前三个序列值是无效的,这主要是与我们原序列切割方式有关的。
一 定义参数并生成样本数据
np.random.seed(0) ''' 一 定义参数生成样本数据 ''' num_epochs = 5 #迭代轮数 total_series_length = 50000 #序列样本数据长度 truncated_backprop_length = 15 #测试时截取数据长度 state_size = 4 #中间状态长度 num_classes = 2 #输出类别个数 echo_step = 3 #回声步长 batch_size = 5 #小批量大小 learning_rate = 0.4 #学习率 num_batches =total_series_length//batch_size//truncated_backprop_length #计算一轮可以分为多少批 def generate_date(): ''' 生成原序列和回声序列数据,回声序列滞后原序列echo_step个步长 返回原序列和回声序列组成的元组 ''' #生成原序列样本数据 random.choice()随机选取内容从0和1中选取total_series_length个数据,0,1数据的概率都是0.5 x = np.array(np.random.choice(2,total_series_length,p=[0.5,0.5])) #向右循环移位 如11110000->00011110 y =np.roll(x,echo_step) #回声序列,前echo_step个数据清0 y[0:echo_step] = 0 x = x.reshape((batch_size,-1)) #5x10000 #print(x) y = y.reshape((batch_size,-1)) #5x10000 #print(y) return (x,y)
二 定义占位符处理输入数据
定义三个占位符,batch_x为原始序列,batch_y为回声序列真实值,init_state为循环节点的初始值。batch_x是逐个输入网络的,所以需要将输进去的数据打散,按照时间序列变成15个数组,每个数组有batch_size个元素,进行统一批处理。
''' 二 定义占位符处理输入数据 ''' batch_x = tf.placeholder(dtype=tf.float32,shape=[batch_size,truncated_backprop_length]) #原始序列 batch_y = tf.placeholder(dtype=tf.int32,shape=[batch_size,truncated_backprop_length]) #回声序列 作为标签 init_state = tf.placeholder(dtype=tf.float32,shape=[batch_size,state_size]) #循环节点的初始状态值 #将batch_x沿axis = 1(列)的轴进行拆分 返回一个list 每个元素都是一个数组 [(5,),(5,)....] 一共15个元素,即15个序列 inputs_series = tf.unstack(batch_x,axis=1) labels_series = tf.unstack(batch_y,axis=1)
三 定义网络结构
定义一层循环与一层全网络连接。由于数据是一个二维数组序列,所以需要通过循环将输入数据按照原有序列逐个输入网络,并输出对应的predictions序列,同样的,对于每个序列值都要对其做loss计算,在loss计算使用了spare_softmax_cross_entropy_with_logits函数,因为label的最大值正好是1,而且是一位的,就不需要在使用one_hot编码了,最终将所有的loss均值放入优化器中。
''' 三 定义RNN网络结构 一个输入样本由15个输入序列组成 一个小批量包含5个输入样本 ''' current_state = init_state #存放当前的状态 predictions_series = [] #存放一个小批量中每个输入样本的预测序列值 每个元素为5x2 共有15个元素 losses = [] #存放一个小批量中每个输入样本训练的损失值 每个元素是一个标量,共有15个元素 #使用一个循环,按照序列逐个输入 for current_input,labels in zip(inputs_series,labels_series): #确定形状为batch_size x 1 current_input = tf.reshape(current_input,[batch_size,1]) ''' 加入初始状态 5 x 1序列值和 5 x 4中间状态 按列连接,得到 5 x 5数组 构成输入数据 ''' input_and_state_concatenated = tf.concat([current_input,current_state],1) #隐藏层激活函数选择tanh 5x4 next_state = tf.contrib.layers.fully_connected(input_and_state_concatenated,state_size,activation_fn = tf.tanh) current_state = next_state #输出层 激活函数选择None,即直接输出 5x2 logits = tf.contrib.layers.fully_connected(next_state,num_classes,activation_fn = None) #计算代价 loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits = logits)) losses.append(loss) #经过softmax计算预测值 5x2 注意这里并不是标签值 这里是one_hot编码 predictions = tf.nn.softmax(logits) predictions_series.append(predictions) total_loss = tf.reduce_mean(losses) train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)
亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。
日期 | 姓名 | 金额 |
---|---|---|
2023-09-06 | *源 | 19 |
2023-09-11 | *朝科 | 88 |
2023-09-21 | *号 | 5 |
2023-09-16 | *真 | 60 |
2023-10-26 | *通 | 9.9 |
2023-11-04 | *慎 | 0.66 |
2023-11-24 | *恩 | 0.01 |
2023-12-30 | I*B | 1 |
2024-01-28 | *兴 | 20 |
2024-02-01 | QYing | 20 |
2024-02-11 | *督 | 6 |
2024-02-18 | 一*x | 1 |
2024-02-20 | c*l | 18.88 |
2024-01-01 | *I | 5 |
2024-04-08 | *程 | 150 |
2024-04-18 | *超 | 20 |
2024-04-26 | .*V | 30 |
2024-05-08 | D*W | 5 |
2024-05-29 | *辉 | 20 |
2024-05-30 | *雄 | 10 |
2024-06-08 | *: | 10 |
2024-06-23 | 小狮子 | 666 |
2024-06-28 | *s | 6.66 |
2024-06-29 | *炼 | 1 |
2024-06-30 | *! | 1 |
2024-07-08 | *方 | 20 |
2024-07-18 | A*1 | 6.66 |
2024-07-31 | *北 | 12 |
2024-08-13 | *基 | 1 |
2024-08-23 | n*s | 2 |
2024-09-02 | *源 | 50 |
2024-09-04 | *J | 2 |
2024-09-06 | *强 | 8.8 |
2024-09-09 | *波 | 1 |
2024-09-10 | *口 | 1 |
2024-09-10 | *波 | 1 |
2024-09-12 | *波 | 10 |
2024-09-18 | *明 | 1.68 |
2024-09-26 | B*h | 10 |
2024-09-30 | 岁 | 10 |
2024-10-02 | M*i | 1 |
2024-10-14 | *朋 | 10 |
2024-10-22 | *海 | 10 |
2024-10-23 | *南 | 10 |
2024-10-26 | *节 | 6.66 |
2024-10-27 | *o | 5 |
2024-10-28 | W*F | 6.66 |
2024-10-29 | R*n | 6.66 |
2024-11-02 | *球 | 6 |
2024-11-021 | *鑫 | 6.66 |
2024-11-25 | *沙 | 5 |
2024-11-29 | C*n | 2.88 |

【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了